In [1]:
import torch
from torch import nn
import numpy as np
from copy import deepcopy
from tensorboardX import SummaryWriter
import argparse
from collections import OrderedDict

from general_functions.prune_utils import BN_preprocess, pack_filter_para, generate_searchspace, load_weights_from_loose_model
from general_functions.utils import parse_data_config, get_logger, create_directories_from_list
from building_blocks.builder import SampledNet, Loss
from building_blocks.modeldef import MODEL_ARCH, Test_model_arch

from supernet_main_file import _create_data_loader, _create_test_data_loader

from architecture_functions.training_functions import TrainerArch
from architecture_functions.config_for_arch import CONFIG_ARCH
from supernet_functions.config_for_supernet import CONFIG_SUPERNET
from supernet_functions.lookup_table_builder import (LookUpTable, SEARCH_SPACE_BACKBONE, SEARCH_SPACE_HEAD,                                                                      SEARCH_SPACE_FPN, YOLO_LAYER_26, YOLO_LAYER_13, extract_anchors)
from supernet_functions.model_supernet import YOLOLayer, Stochastic_SuperNet
from supernet_functions.supernet_prune import PrunedModel


# parser = argparse.ArgumentParser("architecture")
# parser.add_argument('--architecture_name', type=str, default='test_structure',
#                     help='You can choose architecture from the building_blocks/modeldef.py')
# parser.add_argument("-d", "--data", type=str, default="./config/detrac.data",
#                     help="Path to data config file (.data)")
# parser.add_argument("--n_cpu", type=int, default=16,
#                     help="Number of cpu threads to use during batch generation")
# args = parser.parse_args([])


def makeYOLOLayer(yolo_layer_26, yolo_layer_13):
    """Build the pair of YOLO detection layers used by the sampled network.

    Args:
        yolo_layer_26: config mapping for the first detection layer; anchors
            are obtained via ``extract_anchors`` and its ``'classes'`` entry
            supplies the class count for BOTH returned layers.
        yolo_layer_13: config mapping for the second detection layer; only
            its anchors are used.

    Returns:
        A ``(YOLOLayer, YOLOLayer)`` tuple, ordered (layer-26, layer-13).
    """
    anchor_sets = [extract_anchors(layer_cfg)
                   for layer_cfg in (yolo_layer_26, yolo_layer_13)]
    # The class count is deliberately read from the 26-layer config only.
    class_count = yolo_layer_26['classes']
    detector_26 = YOLOLayer(anchor_sets[0], class_count)
    detector_13 = YOLOLayer(anchor_sets[1], class_count)
    return detector_26, detector_13


def get_model(arch):
    """Instantiate the un-pruned sampled network for the architecture ``arch``.

    Args:
        arch: key into ``Test_model_arch`` (the test-only architecture table;
            the production ``MODEL_ARCH`` lookup is currently disabled).

    Returns:
        A ``SampledNet`` built from the default backbone / head / FPN search
        spaces. Both detection heads share the same head layer parameters.

    Raises:
        AssertionError: if ``arch`` is not a key of ``Test_model_arch``.
    """
    # for test only: architectures are resolved from Test_model_arch
    assert arch in Test_model_arch
    arch_def = Test_model_arch[arch]

    yolo_layer26, yolo_layer13 = makeYOLOLayer(YOLO_LAYER_26, YOLO_LAYER_13)

    def layer_params_for(space):
        # Only the first element of the returned pair is needed here.
        params, _ = LookUpTable._generate_layers_parameters(search_space=space)
        return params

    backbone_params = layer_params_for(SEARCH_SPACE_BACKBONE)
    head_params = layer_params_for(SEARCH_SPACE_HEAD)
    fpn_params = layer_params_for(SEARCH_SPACE_FPN)

    return SampledNet(
        arch_def,
        num_anchors=len(YOLO_LAYER_26['mask']),
        num_cls=YOLO_LAYER_26['classes'],
        layer_parameters=backbone_params,
        layer_parameters_head26=head_params,
        layer_parameters_head13=head_params,
        layer_parameters_fpn=fpn_params,
        yolo_layer26=yolo_layer26,
        yolo_layer13=yolo_layer13,
    )

def get_pruned_model(arch, num_filters):
    """Rebuild the sampled network with per-layer filter counts after pruning.

    Args:
        arch: key into ``Test_model_arch`` (test-only architecture table).
        num_filters: per-layer filter counts produced by the pruning step; it
            is sliced with ``pack_filter_para`` for each network section and
            also forwarded to ``SampledNet`` as ``prune_para``.

    Returns:
        A ``SampledNet`` whose backbone, FPN and the two heads are built from
        search spaces derived from ``num_filters``.

    Raises:
        AssertionError: if ``arch`` is not a key of ``Test_model_arch``.
    """
    # for test only: architectures are resolved from Test_model_arch
    assert arch in Test_model_arch
    arch_def = Test_model_arch[arch]

    yolo_layer26, yolo_layer13 = makeYOLOLayer(YOLO_LAYER_26, YOLO_LAYER_13)

    # Layer-index ranges of each section, as passed to pack_filter_para.
    backbone_para = pack_filter_para(1, 11, num_filters)
    fpn_para = pack_filter_para(15, 22, num_filters)
    head26_para = pack_filter_para(25, 29, num_filters)
    head13_para = pack_filter_para(30, 34, num_filters)

    def filters_of_layer(layer_id):
        # First entry of the first packed group for a single-layer range.
        return pack_filter_para(layer_id, layer_id, num_filters)[0][0]

    def pruned_space(filter_para, first_input, strides):
        # Same four keys, in the same order, as the hand-written dictionaries.
        input_shape, channel_size, prune = generate_searchspace(filter_para,
                                                                first_input=first_input)
        space = OrderedDict()
        space['input_shape'] = input_shape
        space['channel_size'] = channel_size
        space['prune'] = prune
        space['strides'] = strides
        return space

    def pruned_layer_params(space):
        params, _ = LookUpTable._generate_layers_parameters(search_space=space,
                                                            prune=True)
        return params

    # Backbone starts from a fixed input of 16 (value taken from the original code).
    backbone_space = pruned_space(backbone_para, 16, SEARCH_SPACE_BACKBONE['strides'])

    # The FPN input is the filter count of layer 14.
    fpn_space = pruned_space(fpn_para, filters_of_layer(14), SEARCH_SPACE_FPN['strides'])

    # The heads take their inputs from layers 23 and 24 respectively.
    input_head26 = filters_of_layer(23)
    input_head13 = filters_of_layer(24)
    head26_space = pruned_space(head26_para, input_head26, SEARCH_SPACE_HEAD['strides'])
    head13_space = pruned_space(head13_para, input_head13, SEARCH_SPACE_HEAD['strides'])

    backbone_params = pruned_layer_params(backbone_space)
    fpn_params = pruned_layer_params(fpn_space)
    head26_params = pruned_layer_params(head26_space)
    head13_params = pruned_layer_params(head13_space)

    return SampledNet(
        arch_def,
        num_anchors=len(YOLO_LAYER_26['mask']),
        num_cls=YOLO_LAYER_26['classes'],
        layer_parameters=backbone_params,
        layer_parameters_head26=head26_params,
        layer_parameters_head13=head13_params,
        layer_parameters_fpn=fpn_params,
        yolo_layer26=yolo_layer26,
        yolo_layer13=yolo_layer13,
        prune_para=num_filters,
    )


def _partition_supernet_keys(state_dict, layer_ranges):
    """Classify supernet state-dict keys into "drop" and "rename" lists.

    Keys are split on ``'.'``. The second component is matched (as a string)
    against the layer indices of each ``(start, end)`` half-open range. For a
    matching key:

    * if the last component is ``'AP_path_alpha'``, the argmax of that tensor
      becomes the currently chosen candidate index;
    * if the key has more than three components, its fourth component is
      compared with the chosen index: equal -> "rename", different -> "drop".

    The chosen index is the most recently seen one, so this relies on each
    layer's ``AP_path_alpha`` entry preceding that layer's candidate weights
    in the state-dict iteration order.

    Args:
        state_dict: mapping of parameter name -> tensor (not modified).
        layer_ranges: sequence of ``(start, end)`` layer-index ranges.

    Returns:
        A list with one ``(del_keys, rev_keys)`` pair per entry of
        ``layer_ranges``, in the same order.

    Raises:
        RuntimeError: if a candidate weight is met before any
            ``AP_path_alpha`` entry, i.e. there is no choice to compare with.
    """
    layer_id_sets = [{str(i) for i in range(start, end)} for start, end in layer_ranges]
    partitions = [([], []) for _ in layer_ranges]
    chosen_id = None  # local state; replaces the former module-level global

    for key in state_dict.keys():
        parts = key.split('.')
        if len(parts) < 2:
            # No layer component to inspect.
            continue
        for layer_ids, (del_keys, rev_keys) in zip(layer_id_sets, partitions):
            if parts[1] not in layer_ids:
                continue
            if parts[-1] == 'AP_path_alpha':
                chosen_id = np.argmax(state_dict[key].cpu().numpy())
            if len(parts) > 3:
                if chosen_id is None:
                    raise RuntimeError(
                        "No 'AP_path_alpha' entry seen before key '{}'; "
                        "cannot decide which candidate to keep".format(key))
                if parts[3] != str(chosen_id):
                    del_keys.append(key)
                else:
                    rev_keys.append(key)
    return partitions


def _keep_chosen_ops(state_dict, del_keys, rev_keys):
    """Drop unchosen entries and rename chosen ones, modifying ``state_dict`` in place.

    A kept key loses its third and fourth dot-separated components so that it
    matches the sub-model's parameter names.
    """
    for del_key in del_keys:
        del state_dict[del_key]

    for old_name in rev_keys:
        parts = old_name.split('.')
        new_name = '.'.join(parts[0:2] + parts[4:])
        state_dict[new_name] = state_dict.pop(old_name)


def main(data_config_path="./config/detrac.data", n_cpu=16):
    """Sample a sub-model from the trained supernet, prune it and rebuild it.

    Steps: seed the RNGs, build the validation loader, build the sub-model,
    copy the chosen candidates' weights out of the supernet checkpoint, save
    the sub-model, prune a deep copy of it, rebuild a compact model from the
    resulting filter counts, reload the weights, set up optimizer / scheduler
    / trainer, and run one random batch through the pruned model as a smoke
    test. Fine-tuning itself is currently disabled.

    Args:
        data_config_path: path to the ``.data`` config file; its ``"valid"``
            entry is used to build the validation loader needed by pruning.
        n_cpu: number of worker threads passed to the data loader.

    Raises:
        ValueError: if the configured scheduler name is not supported.
        RuntimeError: if the checkpoint holds candidate weights that are not
            preceded by an ``AP_path_alpha`` entry.
    """
    manual_seed = 1
    np.random.seed(manual_seed)
    torch.manual_seed(manual_seed)
    torch.cuda.manual_seed_all(manual_seed)
    torch.backends.cudnn.benchmark = True

    create_directories_from_list([CONFIG_ARCH['logging']['path_to_tensorboard_logs']])

    logger = get_logger(CONFIG_ARCH['logging']['path_to_log_file'])
    writer = SummaryWriter(log_dir=CONFIG_ARCH['logging']['path_to_tensorboard_logs'])

    #### DataLoading
    # The validation loader is required by PrunedModel.prune_and_eval below.
    # The train loader is not built because the training loop is disabled.
    data_config = parse_data_config(data_config_path)
    valid_path = data_config["valid"]
    valid_loader = _create_test_data_loader(valid_path,
                                            CONFIG_ARCH['dataloading']['batch_size'],
                                            CONFIG_ARCH['dataloading']['img_size'],
                                            n_cpu)

    #### Model
    arch = 'test_structure'
    sub_model = get_model(arch).cuda()

    #### Load Parameters
    lookup_table = LookUpTable()
    supernet = Stochastic_SuperNet(lookup_table=lookup_table)
    checkpoint = torch.load(CONFIG_SUPERNET['train_settings']['path_to_save_model'])
    supernet.load_state_dict(checkpoint["state_dict"])

    # Shallow copy: keys can be removed/renamed without touching the supernet.
    supernet_copy = supernet.state_dict().copy()

    # Half-open layer-index ranges: backbone, fpn, head.
    layer_ranges = [(1, 12), (15, 23), (25, 35)]
    for del_keys, rev_keys in _partition_supernet_keys(supernet_copy, layer_ranges):
        _keep_chosen_ops(supernet_copy, del_keys, rev_keys)

    # strict=False tolerates name mismatches, so report them instead of hiding them.
    missing_keys, unexpected_keys = sub_model.load_state_dict(state_dict=supernet_copy, strict=False)
    logger.info("sub-model weights loaded: {} missing keys, {} unexpected keys".format(
        len(missing_keys), len(unexpected_keys)))

    # save the sub-model
    torch.save(sub_model.state_dict(), CONFIG_ARCH['sub-model-saving'])
    model = deepcopy(sub_model)

    # Pruning can start from here.
    # Step 1: BN synchronisation (translated from the original comment).
    BN_preprocess(model)

    #######
    # prune
    #######
    # get highest prune ratio
    highest_thre, percent_limit = PrunedModel.get_highest_thre(model)
    # get the threshold and evaluate the pruned model
    threshold = PrunedModel.prune_and_eval(model, valid_loader,
                                           CONFIG_ARCH['dataloading']['img_size'],
                                           percent=percent_limit-0.001)

    # get num_filters for re-constructing the pruned model
    num_filters, filters_mask = PrunedModel.obtain_filters_mask(model, threshold)
    # rebuild the model with the new layer parameters
    pruned_model = get_pruned_model(arch, num_filters).cuda()

    # reload parameters
    print("reload the parameters...")
    load_weights_from_loose_model(pruned_model, model, filters_mask)
    print("finish reloading!")

    #### Loss and Optimizer
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, pruned_model.parameters()),
                                lr=CONFIG_ARCH['optimizer']['lr'],
                                momentum=CONFIG_ARCH['optimizer']['momentum'],
                                weight_decay=CONFIG_ARCH['optimizer']['weight_decay'])
    criterion = Loss().cuda()

    #### Scheduler
    scheduler_name = CONFIG_ARCH['train_settings']['scheduler']
    if scheduler_name == 'MultiStepLR':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                         milestones=CONFIG_ARCH['train_settings']['milestones'],
                                                         gamma=CONFIG_ARCH['train_settings']['lr_decay'])
    elif scheduler_name == 'CosineAnnealingLR':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                               T_max=CONFIG_ARCH['train_settings']['cnt_epochs'],
                                                               eta_min=0.001, last_epoch=-1)
    else:
        # Fail here with a clear message instead of an unbound `scheduler` below.
        message = "Please, specify scheduler in architecture_functions/config_for_arch"
        logger.info(message)
        raise ValueError(message)

    #### Training Loop
    trainer = TrainerArch(criterion, optimizer, scheduler, logger, writer)

    # Smoke test: push one random batch through the pruned model and show the
    # shapes of its two outputs.
    pruned_model.train()
    dummy_input = torch.randn(100, 3, 416, 416).cuda()
    out = pruned_model(dummy_input)
    print(out[0].shape)
    print(out[1].shape)

    # Fine-tuning is disabled. To enable it, build a train loader with
    # _create_data_loader(...) and call:
    # trainer.train_loop(train_loader, valid_loader, pruned_model)
    
# Run the whole pipeline when this cell/script is executed directly
# (i.e. when __name__ is "__main__"), not when the code is imported.
if __name__ == "__main__":
    main()

usage: action [-h] [--train_or_sample TRAIN_OR_SAMPLE] [-d DATA]
              [--n_cpu N_CPU] [--resume RESUME]
              [--architecture_name ARCHITECTURE_NAME]
action: error: unrecognized arguments: -f /pfs/data5/home/kit/tm/px6680/.local/share/jupyter/runtime/kernel-74c3457a-175f-4855-82e7-01bc99f4febd.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [4]:
# Load the TensorBoard notebook extension and open it on the supernet log directory.
%load_ext tensorboard
%tensorboard --logdir ./supernet_functions/logs/tb