## Prior-Guided Neural Architecture Search 

Competition Homepage: [2022 CVPR Track1: SuperNet Track](https://aistudio.baidu.com/aistudio/competition/detail/149/0/introduction)

Table of Contents

- Environment Setup
- SuperNet Training
- SubNet Evaluation

## 1. Environment Setup


1.1 Env Requirements

```
paddle==2.2.2
fire
matplotlib
visualdl 
```

Install them with:

```bash
pip install -r requirements.txt 
```

1.2 Data Preparation

You can download imagenet-mini [from kaggle](https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000). 

ImageNet-mini is a small subset of ImageNet-1k: it keeps all 1000 classes but takes only 4.24 GB.

For quick iteration, we use ImageNet-mini for both training and evaluation. If you have enough computational resources, you can train on ImageNet-1k directly.

On Baidu AI Studio, the [imagenet-mini dataset](https://aistudio.baidu.com/aistudio/datasetdetail/89857) is also provided.


1.3 Architecture JSON Preparation

[Here](https://aistudio.baidu.com/aistudio/datasetdetail/134077) you can get the 45,000 test architectures. Download `CVPR_2022_NAS_Track1_test.json` and put it at `checkpoints/CVPR_2022_NAS_Track1_test.json`.


## 2. SuperNet Training

In this part, we will introduce the overall flow of training.

(1) Import necessary modules.

In [None]:
import os

import paddle
import paddle.nn as nn
from paddle.nn import CrossEntropyLoss
from paddle.vision.transforms import (
    RandomHorizontalFlip, RandomResizedCrop, SaturationTransform, 
    Compose, Resize, HueTransform, BrightnessTransform, ContrastTransform, 
    RandomCrop, Normalize, RandomRotation, CenterCrop)
from paddle.io import DataLoader
from paddle.optimizer.lr import CosineAnnealingDecay, MultiStepDecay, LinearWarmup

from hnas.utils.callbacks import LRSchedulerM, MyModelCheckpoint
from hnas.utils.transforms import ToArray
from hnas.dataset.random_size_crop import MyRandomResizedCrop
from paddle.vision.datasets import DatasetFolder

from paddleslim.nas.ofa.convert_super import Convert, supernet
from paddleslim.nas.ofa import RunConfig, DistillConfig, ResOFA
from paddleslim.nas.ofa.utils import utils

import paddle.distributed as dist
from hnas.utils.yacs import CfgNode
from hnas.models.builder import build_classifier
from hnas.utils.hapi_wrapper import Trainer

: 

(2) Set the loss function and accuracy tools.

We offer three ways to compute loss:

- Normal CrossEntropy Loss for teacher network.  
- Inplace Distillation for student network.
- Knowledge Distillation for student network.

In [None]:
def _loss_forward(self, input, tea_input=None, label=None):
    if tea_input is not None and label is not None:
        # knoledge distillation = cross entropy + inplace distillation
        ce = paddle.nn.functional.cross_entropy(
            input,
            label,
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
            soft_label=self.soft_label,
            axis=self.axis,
            name=self.name)

        kd = paddle.nn.functional.cross_entropy(
            input,
            paddle.nn.functional.softmax(tea_input),
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
            soft_label=True,
            axis=self.axis)
        return ce, kd
    elif tea_input is not None and label is None:
        # inplace distillation
        kd = paddle.nn.functional.cross_entropy(
            input,
            paddle.nn.functional.softmax(tea_input),
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
            soft_label=True,
            axis=self.axis)
        return kd 
    elif label is not None:
        # normal cross entropy 
        ce = paddle.nn.functional.cross_entropy(
            input,
            label,
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
            soft_label=False,
            axis=self.axis,
            name=self.name)
        return ce
    else:
        raise "Not Implemented Loss."

# Monkey-patch paddle's CrossEntropyLoss so one loss object can compute plain
# cross entropy, in-place distillation, or knowledge distillation depending on
# which of `tea_input` / `label` is passed (see _loss_forward).
CrossEntropyLoss.forward = _loss_forward

def _compute(self, pred, tea_pred, label=None, *args):
    """Top-k hit mask used by ``paddle.metric.Accuracy`` (patched in below).

    When called as ``compute(pred, labels)`` the second positional argument
    is taken as the label. Labels of shape ``[N]`` or ``[N, 1]`` are reshaped
    to a column. Labels of any other shape whose last dimension is not 1 are
    treated as soft / one-hot and reduced with ``argmax``.

    Returns:
        float32 tensor; assuming a 2-D ``pred`` it has shape ``[N, maxk]``.
        Entry ``(i, j)`` is 1 where the j-th ranked prediction of sample i
        equals its label.
    """
    target = tea_pred if label is None else label

    order = paddle.argsort(pred, descending=True)
    last_axis = len(order.shape) - 1
    top_k = paddle.slice(order, axes=[last_axis], starts=[0], ends=[self.maxk])

    rank = len(target.shape)
    if rank == 1 or (rank == 2 and target.shape[-1] == 1):
        # Hard labels: make them a column so they broadcast against top_k.
        target = paddle.reshape(target, (-1, 1))
    elif target.shape[-1] != 1:
        # Soft / one-hot labels: reduce to the class index.
        target = paddle.argmax(target, axis=-1, keepdim=True)

    hits = top_k == target
    return paddle.cast(hits, dtype='float32')

# Install the patched compute on paddle's Accuracy metric.
paddle.metric.Accuracy.compute = _compute

(3) Hyperparameter Settings.

In [None]:
# --- model / input ---
backbone = 'resnet48_prelu'
image_size = '224'  # comma-separated string; run() splits it into ints

# --- optimisation ---
max_epoch = 70
lr = 0.001
weight_decay = 0.0
momentum = 0.9
batch_size = 256
dyna_batch_size = 4
warmup = 2

# --- checkpoints / data / logging ---
phase = None
resume = None
pretrained = 'checkpoints/resnet48.pdparams'
image_dir = '/root/paddlejob/workspace/env_run/data/ILSVRC2012/'
save_dir = 'checkpoints/res48-depth'
save_freq = 20
log_freq = 100
visualdl_dir = "./visualdl_log/autoslim3"

(4) The `run` and `main` functions.

In [None]:
def run(
    backbone='resnet48',
    image_size='224',
    max_epoch=120,
    lr=0.0025,
    weight_decay=0.,
    momentum=0.9,
    batch_size=80,
    dyna_batch_size=4,
    warmup=2,
    phase=None,
    resume=None,
    pretrained='checkpoints/resnet48.pdparams',
    image_dir='/root/paddlejob/workspace/env_run/data/ILSVRC2012/',
    save_dir='checkpoints/res48-depth',
    save_freq=20,
    log_freq=100,
    visualdl_dir="./visualdl_log/autoslim3",
    **kwargs
    ):
    """Assemble the training config and spawn one ``main`` process per GPU.

    Every named argument, plus any extra keyword argument, becomes a field of
    the ``CfgNode`` handed to ``main``. ``image_size`` is a comma-separated
    string that is also parsed into the integer list ``image_size_list``.
    """
    # locals() must be taken first so it holds exactly the call arguments.
    run_config = locals()
    run_config.update(run_config["kwargs"])
    del run_config["kwargs"]
    config = CfgNode(run_config)
    config.image_size_list = list(map(int, config.image_size.split(',')))

    # The length of the CUDA RNG state list is used as the GPU count.
    nprocs = len(paddle.get_cuda_rng_state())
    gpu_str = ','.join(map(str, range(nprocs)))
    print(f'gpu num: {nprocs}')
    dist.spawn(main, args=(config,), nprocs=nprocs, gpus=gpu_str)


def main(cfg):
    """Per-process training entry point (launched by ``run`` via ``dist.spawn``).

    Builds the training set from ``<cfg.image_dir>/train``. Converts the
    backbone into a ResOFA supernet, distilled from a separately built copy of
    the same backbone. Wraps the supernet in ``Trainer`` and fits it for
    ``cfg.max_epoch`` epochs.

    Args:
        cfg: ``CfgNode`` assembled by ``run``. ``cfg.lr`` is overwritten in
            place with the batch-size-scaled learning rate.
    """
    paddle.set_device('gpu:{}'.format(dist.ParallelEnv().device_id))
    if dist.get_rank() == 0:
        print(cfg)
    # Per-channel mean / std for Normalize below (the usual ImageNet statistics).
    IMAGE_MEAN = (0.485,0.456,0.406)
    IMAGE_STD = (0.229,0.224,0.225)

    # Linear LR scaling: cfg.lr is specified for a global batch of 256.
    cfg.lr = cfg.lr * cfg.batch_size * dist.get_world_size() / 256
    # NOTE(review): 1281024 looks like the ImageNet-1k train-set size, so this is
    # `cfg.warmup` epochs' worth of steps for full ImageNet regardless of the
    # dataset actually in cfg.image_dir. With imagenet-mini, len(train_set)
    # would be the matching value -- confirm before changing.
    warmup_step = int(1281024 / (cfg.batch_size * dist.get_world_size())) * cfg.warmup

    # data augmentation 
    transforms = Compose([
        MyRandomResizedCrop(cfg.image_size_list),
        RandomHorizontalFlip(),
        ToArray(),
        Normalize(IMAGE_MEAN, IMAGE_STD),
    ])
    train_set = DatasetFolder(os.path.join(cfg.image_dir, 'train'), transform=transforms)
    callbacks = [LRSchedulerM(), 
                 MyModelCheckpoint(cfg.save_freq, cfg.save_dir, cfg.resume, cfg.phase),
                 paddle.callbacks.VisualDL(log_dir=cfg.visualdl_dir)]

    # build resnet48 and teacher net
    # (same backbone and checkpoint; only the `reorder` flag differs)
    net = build_classifier(cfg.backbone, pretrained=cfg.pretrained, reorder=True)
    tnet = build_classifier(cfg.backbone, pretrained=cfg.pretrained, reorder=False)
    # Keep references to the pretrained parameters so they can be re-loaded
    # into the converted supernet below.
    origin_weights = {}
    for name, param in net.named_parameters():
        origin_weights[name] = param
    
    # convert resnet48 to supernet 
    sp_model = Convert(supernet(expand_ratio=[1.0])).convert(net)  # convert net into a supernet
    utils.set_state_dict(sp_model, origin_weights)  # reload the pretrained weights into the supernet
    del origin_weights

    # set candidate config 
    # 'd' presumably holds (min, max) block counts for each of the 4 stages -- confirm.
    cand_cfg = {
            'i': [224],  # image size
            'd': [(2, 5), (2, 5), (2, 8), (2, 5)],  # depth
            'k': [3],  # kernel size
            'c': [1.0, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7] # channel ratio
    }

    # Distillation settings: tnet is the teacher, weighted by lambda_distill.
    default_distill_config = {
        'lambda_distill': 0.5,
        'teacher_model': tnet,
        'mapping_layers': None,
        'teacher_model_path': None,
        'distill_fn': None,
        'mapping_op': 'conv2d'
    }

    ofa_net = ResOFA(sp_model,
                     distill_config=DistillConfig(**default_distill_config), 
                     candidate_config=cand_cfg,
                     block_conv_num=2)

    # Only the channel ('expand_ratio') dimension is searched here; the
    # commented-out line would also search depth.
    # ofa_net.set_task(['depth', 'expand_ratio'])
    ofa_net.set_task('expand_ratio')

    # Number of subnets sampled per step, as read by Trainer -- presumably; confirm.
    run_config = {'dynamic_batch_size': cfg.dyna_batch_size}
    model = Trainer(ofa_net, cfg=run_config)

    # calculate loss by ce 
    # LR: linear warm-up from 0 to cfg.lr over warmup_step scheduler steps, then
    # cosine decay with T_max = cfg.max_epoch. NOTE(review): warmup_step counts
    # iterations while T_max is in epochs; how often LRSchedulerM steps the
    # scheduler is not visible here -- confirm the units agree.
    model.prepare(
        paddle.optimizer.Momentum(
            learning_rate=LinearWarmup( # delete cfg.lr * 0.05 
                CosineAnnealingDecay(cfg.lr, cfg.max_epoch), warmup_step, 0., cfg.lr),
            momentum=cfg.momentum,
            parameters=model.parameters(),
            weight_decay=cfg.weight_decay),
        CrossEntropyLoss(),
        paddle.metric.Accuracy(topk=(1,5)))
        
    # No evaluation set is passed (second argument is None).
    model.fit(
        train_set,
        None,
        epochs=cfg.max_epoch,
        batch_size=cfg.batch_size,
        save_dir=cfg.save_dir,
        save_freq=cfg.save_freq,
        log_freq=cfg.log_freq,
        shuffle=True,
        num_workers=8,
        verbose=2, 
        drop_last=True,
        callbacks=callbacks,
    )

(5) Start to run.

In [None]:
# Forward every hyperparameter from the cell above. Previously phase, resume,
# save_freq and log_freq were omitted, so edits to them were silently replaced
# by run()'s own defaults.
run(backbone=backbone,
    image_size=image_size,
    max_epoch=max_epoch,
    lr=lr,
    weight_decay=weight_decay,
    momentum=momentum,
    batch_size=batch_size,
    dyna_batch_size=dyna_batch_size,
    warmup=warmup,
    phase=phase,
    resume=resume,
    pretrained=pretrained,
    image_dir=image_dir,
    save_dir=save_dir,
    save_freq=save_freq,
    log_freq=log_freq,
    visualdl_dir=visualdl_dir,
    )

The code above does not show the core of our method. Below we describe how the pairwise `Rank Loss` is built.

Rank Loss Function: 

In [None]:
import paddle.nn as nn
import paddle.nn.functional as F

class PairwiseRankLoss(nn.Layer):
    """Pairwise ranking loss that keeps two sub-networks' losses consistent
    with a scalar prior (e.g. FLOPs).

    The architecture with the larger prior is expected to reach the lower
    loss. If it does not, the hinge ``relu(loss_larger - loss_smaller)`` is
    returned, scaled by ``coeff``. The smaller-prior architecture's loss is
    detached in both branches, so the gradient only pushes the larger-prior
    architecture's loss down and never pushes the smaller one's loss up.
    """

    def forward(self, prior1, prior2, loss1, loss2, coeff=1.):
        """Compute the pairwise rank hinge.

        Args:
            prior1 (float | int): the prior value of arch1.
            prior2 (float | int): the prior value of arch2.
            loss1: the batch loss of arch1.
            loss2: the batch loss of arch2.
            coeff (float): multiplier applied to the hinge.

        Returns:
            ``coeff * relu(loss_of_larger_prior - loss_of_smaller_prior.detach())``.
            On a tie (``prior1 == prior2``) arch1 is treated as the larger one.
        """
        if prior1 < prior2:
            # arch2 has the larger prior: it should not be worse than arch1.
            return coeff * F.relu(loss2 - loss1.detach())
        # arch1 has the larger (or equal) prior: it should not be worse than arch2.
        # Detach the smaller-prior side (loss2), mirroring the branch above.
        return coeff * F.relu(loss1 - loss2.detach())

The pairwise `Rank Loss` is integrated into the `Sandwich Rule`; the example below uses FLOPs as the prior.

```python
def train_batch_sandwich_with_rank(self, inputs, labels=None, **kwargs):
    """Run one sandwich-rule training step with an extra pairwise rank loss.

    Steps:
      1. The largest subnet (teacher) is trained with hard-label cross entropy.
      2. The smallest subnet and ``dyna_bs - 2`` random subnets are trained by
         in-place distillation from the teacher's (detached) logits.
      3. ``dyna_bs - 2`` pairs of random subnets are trained with the pairwise
         rank loss, using FLOPs as the prior. The loss weight ramps up as
         ``min(2, epoch / 10)``.
    Gradients from all passes are accumulated and applied in one optimizer step.

    Returns:
        ``([teacher_loss], metrics)`` when metrics are configured, otherwise
        ``[teacher_loss]``. Metrics are computed on the teacher's output so they
        describe the same network as the returned loss.
    """
    assert self.model._optimizer, "model not ready, please call `model.prepare()` first"
    self.model.network.model.train()  # set network to training mode.
    self.mode = 'train'

    inputs = to_list(inputs)
    self._input_info = _update_input_info(inputs)
    labels = to_variable(labels).squeeze(0)
    epoch = kwargs.get('epoch', None)
    self.epoch = epoch
    nBatch = kwargs.get('nBatch', None)
    step = kwargs.get('step', None)

    # Seed derived from the global step, presumably so every rank samples the
    # same subnets -- confirm.
    subnet_seed = int('%d%.1d' % (epoch * nBatch + step, step))
    np.random.seed(subnet_seed)

    def activate(sample_type):
        # Sample a subnet of the given kind and make it the active one.
        config = self.model.network.active_autoslim_subnet(sample_type=sample_type)
        self.model.network.set_net_config(config)

    def forward():
        # Forward the active subnet (through the DDP wrapper when distributed).
        net = self.ddp_model if self._nranks > 1 else self.model.network
        return net.forward(*[to_variable(x) for x in inputs])

    ######### Sandwich Rule ##############

    # Largest subnet = teacher, trained with normal cross entropy.
    activate("largest")
    teacher_output = forward()
    loss1 = self.model._loss(input=teacher_output[0], tea_input=None, label=labels)
    loss1.backward()
    # backward() above runs without retain_graph, so the teacher's graph is
    # released. Distil from detached logits so student losses neither try to
    # back-propagate into it nor update the teacher path.
    teacher_logits = teacher_output[0].detach()

    # Smallest subnet, then random subnets: in-place distillation.
    for sample_type in ["smallest"] + ["random"] * (self.dyna_bs - 2):
        activate(sample_type)
        output = forward()
        loss = self.model._loss(input=output[0], tea_input=teacher_logits, label=None)
        loss.backward()
        del output

    # Pairwise rank loss between two random subnets, FLOPs as the prior.
    rank_weight = min(2, epoch / 10.)  # gradually increase rank loss weight.
    for _ in range(self.dyna_bs - 2):
        activate("random")
        flops1 = get_arch_flops(self.model.network.gen_subnet_code)
        output = forward()
        loss3 = self.model._loss(input=output[0], tea_input=None, label=labels)
        del output

        activate("random")
        flops2 = get_arch_flops(self.model.network.gen_subnet_code)
        output = forward()
        loss4 = self.model._loss(input=output[0], tea_input=None, label=labels)
        del output

        loss5 = rank_weight * self.pairwise_rankloss(flops1, flops2, loss3, loss4)
        loss5.backward()

    self.model._optimizer.step()
    self.model._optimizer.clear_grad()

    # Metrics are computed on the teacher output, matching the returned loss1.
    # (Using the last `output` failed with dyna_bs == 2, where it was deleted.)
    metrics = []
    for metric in self.model._metrics:
        metric_outs = metric.compute(teacher_output[0], labels)
        m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
        metrics.append(m)

    losses = [to_numpy(loss1)]
    return (losses, metrics) if metrics else losses
```

## 3. SubNet Evaluation

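For evaluation, each candidate architecture listed in the test JSON is activated in the trained supernet and evaluated on the validation set; its top-1 accuracy is written into a timestamped copy of the JSON.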

In [None]:
import os

import paddle
import paddle.distributed as dist
from paddle.nn import CrossEntropyLoss
from paddle.optimizer.lr import CosineAnnealingDecay, LinearWarmup
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import CenterCrop, Compose, Normalize, Resize

from hnas.models.builder import build_classifier
from hnas.utils.callbacks import EvalCheckpoint
from hnas.utils.transforms import ToArray
from hnas.utils.yacs import CfgNode
from paddleslim.nas.ofa import DistillConfig, ResOFA
from paddleslim.nas.ofa.convert_super import Convert, supernet
from paddleslim.nas.ofa.utils import utils

In [None]:
# --- model / input ---
backbone = 'resnet48_prelu'
image_size = '224'  # comma-separated string; run() splits it into ints

# --- optimisation (only used to build the optimizer passed to prepare()) ---
max_epoch = 70
lr = 0.001
weight_decay = 0.0
momentum = 0.9
batch_size = 256
dyna_batch_size = 4
warmup = 2

# --- checkpoints / data / logging ---
phase = None
resume = None
pretrained = 'checkpoints/resnet48.pdparams'
image_dir = '/root/paddlejob/workspace/env_run/data/ILSVRC2012/'
save_dir = 'checkpoints/res48-depth'
save_freq = 20
log_freq = 100
visualdl_dir = "./visualdl_log/autoslim3"

In [None]:
def run(
    backbone='resnet48',
    image_size='224',
    max_epoch=120,
    lr=0.0025,
    weight_decay=3e-5,
    momentum=0.9,
    batch_size=80,
    dyna_batch_size=4,
    warmup=2,
    phase=None,
    resume=None,
    pretrained='checkpoints/resnet48.pdparams',
    image_dir='/root/paddlejob/workspace/env_run/data/ILSVRC2012/',
    save_dir='checkpoints/res48-depth',
    save_freq=5,
    log_freq=100,
    json_path=None,
    **kwargs
    ):
    """Assemble the evaluation config and run ``main`` in this process.

    Every named argument, plus any extra keyword argument, becomes a field of
    the ``CfgNode`` handed to ``main``. ``json_path`` names the candidate
    architecture JSON that ``main`` evaluates.
    """
    # locals() must be taken first so it holds exactly the call arguments.
    run_config = locals()
    run_config.update(run_config["kwargs"])
    del run_config["kwargs"]
    config = CfgNode(run_config)
    config.image_size_list = list(map(int, config.image_size.split(',')))

    # The length of the CUDA RNG state list is used as the GPU count.
    nprocs = len(paddle.get_cuda_rng_state())
    gpu_str = ','.join(map(str, range(nprocs)))
    print(f'gpu num: {nprocs}')
    # Evaluation runs in a single process; multi-GPU alternative:
    # dist.spawn(main, args=(config,), nprocs=nprocs, gpus=gpu_str)
    main(config)


def main(cfg):
    """Evaluation entry point: score every candidate subnet on the val set.

    Rebuilds the same ResOFA supernet as in training (loaded from
    ``cfg.pretrained``). Then calls ``Trainer.evaluate_whole_test`` on
    ``<cfg.image_dir>/val`` with the candidate JSON at ``cfg.json_path``.

    Args:
        cfg: ``CfgNode`` assembled by ``run``; ``cfg.json_path`` must be set.
    """
    paddle.set_device('gpu:{}'.format(dist.ParallelEnv().device_id))
    if dist.get_rank() == 0:
        print(cfg)
    # Per-channel mean / std for Normalize below (the usual ImageNet statistics).
    IMAGE_MEAN = (0.485,0.456,0.406)
    IMAGE_STD = (0.229,0.224,0.225)

    # LR scaling / warm-up are only used to build the optimizer passed to
    # prepare() below; evaluation does not step it as far as this block shows.
    cfg.lr = cfg.lr * cfg.batch_size * dist.get_world_size() / 256
    warmup_step = int(1281024 / (cfg.batch_size * dist.get_world_size())) * cfg.warmup

    # Deterministic evaluation preprocessing: resize, centre crop to 224, normalise.
    val_transforms = Compose([Resize(256), CenterCrop(224), ToArray(), Normalize(IMAGE_MEAN, IMAGE_STD)])
    val_set = DatasetFolder(os.path.join(cfg.image_dir, 'val'), transform=val_transforms)
    # val_set = HDF5DatasetFolder("/data/home/scv6681/run/data/hdf5/imagenetmini_val.h5", transform=val_transforms)

    eval_callbacks = [EvalCheckpoint('{}/final'.format(cfg.save_dir))]

    # Same backbone and checkpoint as training; only the `reorder` flag differs.
    net = build_classifier(cfg.backbone, pretrained=cfg.pretrained, reorder=True)
    tnet = build_classifier(cfg.backbone, pretrained=cfg.pretrained, reorder=False)
    origin_weights = {}
    for name, param in net.named_parameters():
        origin_weights[name] = param
    
    sp_model = Convert(supernet(expand_ratio=[1.0])).convert(net)  # convert net into a supernet
    utils.set_state_dict(sp_model, origin_weights)  # reload the pretrained weights into the supernet
    del origin_weights

    # Must match the search space used during training.
    cand_cfg = {
            'i': [224],  # image size
            'd': [(2, 5), (2, 5), (2, 8), (2, 5)],  # depth
            'k': [3],  # kernel size
            'c': [1.0, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7] # channel ratio
    }
    ofa_net = ResOFA(sp_model,
                     distill_config=DistillConfig(teacher_model=tnet), 
                     candidate_config=cand_cfg,
                     block_conv_num=2)
    ofa_net.set_task('expand_ratio')

    run_config = {'dynamic_batch_size': cfg.dyna_batch_size}
    model = Trainer(ofa_net, cfg=run_config)
    model.prepare(
        paddle.optimizer.Momentum(
            learning_rate=LinearWarmup(
                CosineAnnealingDecay(cfg.lr, cfg.max_epoch), warmup_step, 0., cfg.lr),
            momentum=cfg.momentum,
            parameters=model.parameters(),
            weight_decay=cfg.weight_decay),
        CrossEntropyLoss(),
        paddle.metric.Accuracy(topk=(1,5)))

    # Evaluates every architecture in cfg.json_path on val_set.
    model.evaluate_whole_test(val_set, batch_size=cfg.batch_size, num_workers=8, callbacks=eval_callbacks, json_path=cfg.json_path)

In [None]:
# json_path must be supplied: without it evaluate_whole_test() tries to open
# None. The path is the one given in section 1.3.
run(backbone=backbone,
    image_size=image_size,
    max_epoch=max_epoch,
    lr=lr,
    weight_decay=weight_decay,
    momentum=momentum,
    batch_size=batch_size,
    dyna_batch_size=dyna_batch_size,
    warmup=warmup,
    pretrained=pretrained,
    image_dir=image_dir,
    save_dir=save_dir,
    visualdl_dir=visualdl_dir,
    json_path='checkpoints/CVPR_2022_NAS_Track1_test.json',
    )


```python
def evaluate_whole_test(
        self,
        eval_data,
        batch_size=256,
        log_freq=10,
        verbose=1,
        num_workers=4,
        callbacks=None,
        json_path=None,
        result_dir='checkpoints/results/19th_rkloss_mish_flops_latedecay_sandwich_2times'):
    """Evaluate every candidate architecture listed in ``json_path``.

    For each entry, ``config['arch']`` is activated at resolution 224 and run
    over ``eval_data``. The line ``"<name> <arch> <top1> <top5>"`` is printed
    and appended to ``<result_dir>/channel_sample_<num>.txt``, where ``<num>``
    is the last ``_``-separated token of the JSON file name. Top-1 accuracy is
    stored as ``'acc'``, and the updated dict is written to a timestamped copy
    of the JSON. Only local rank 0 prints and writes files.

    Args:
        eval_data: a ``Dataset`` (wrapped in a ``DataLoader`` here) or a
            ready-made loader.
        batch_size, num_workers: ``DataLoader`` settings for a ``Dataset``.
        log_freq, verbose, callbacks: forwarded to ``config_callbacks``.
        json_path: path of the candidate-architecture JSON (required).
        result_dir: directory for the per-architecture text log; created if
            missing.

    Returns:
        List of the per-architecture result strings.

    Raises:
        ValueError: if ``json_path`` is not given.
    """
    import os
    import time

    if json_path is None:
        raise ValueError("json_path is required: path to the candidate-architecture JSON file.")
    candidate_path = json_path

    with open(candidate_path, "r") as f:
        candidate_dict = json.load(f)
    # Shallow copy: the per-arch inner dicts are shared with candidate_dict.
    save_candidate = candidate_dict.copy()

    if eval_data is not None and isinstance(eval_data, Dataset):
        eval_sampler = None
        eval_loader = DataLoader(
            eval_data,
            batch_sampler=eval_sampler,
            places=self._place,
            shuffle=False,
            num_workers=num_workers,
            batch_size=batch_size,
            return_list=True,
            use_shared_memory=True,
            use_buffer_reader=True)
    else:
        eval_loader = eval_data

    self._test_dataloader = eval_loader

    cbks = config_callbacks(
        callbacks,
        model=self,
        log_freq=log_freq,
        verbose=verbose,
        metrics=self._metrics_name(), )

    eval_steps = self._len_data_loader(eval_loader)

    self.network.model.eval()

    is_master = ParallelEnv().local_rank == 0
    show_flag = True

    # Per-architecture text log; the directory is created up front so the
    # first append cannot fail with FileNotFoundError.
    num = json_path.split('_')[-1].split(".")[0]
    result_file = os.path.join(result_dir, f'channel_sample_{num}.txt')
    if is_master:
        os.makedirs(result_dir, exist_ok=True)

    sample_result = []
    for arch_name, config in candidate_dict.items():
        s1 = time.time()
        cbks.on_begin('eval', {'steps': eval_steps, 'metrics': self._metrics_name()})

        self.network.active_specific_subnet(224, config['arch'])

        logs = self._run_one_epoch(eval_loader, cbks, 'eval')

        s3 = time.time()
        if is_master and show_flag:
            print("forward_one_epoch time: ", s3 - s1)

        cbks.on_end('eval', logs)

        self._test_dataloader = None

        eval_result = {k: logs[k] for k in self._metrics_name()}
        sample_res = '{} {} {} {}'.format(arch_name, config['arch'], eval_result['acc_top1'], eval_result['acc_top5'])
        if is_master:
            print(sample_res)

        sample_result.append(sample_res)

        if is_master:
            with open(result_file, 'a') as f:
                f.write('{}\n'.format(sample_res))

        save_candidate[arch_name]['acc'] = eval_result['acc_top1']

    if is_master:
        stamp = time.strftime("%Y_%m_%d__%H_%M_%S", time.localtime())
        save_path = candidate_path.replace('CVPR_2022_NAS_Track1_test', 'CVPR_2022_NAS_Track1_test_{}'.format(stamp))
        if save_path == candidate_path:
            # The file name lacks the expected token; never overwrite the input JSON.
            root, ext = os.path.splitext(candidate_path)
            save_path = '{}_{}{}'.format(root, stamp, ext)
        with open(save_path, 'w') as f:
            json.dump(save_candidate, f)

    return sample_result

```