In [None]:
%run Include.ipynb
import argparse
import shutil
from pathlib import Path

class Argparser(object):
    """Prepare the output directories and parse training flags.

    ``settings`` is a nested mapping; the keys read here are:

    * ``settings["Basic"]``: ``branch_name``, ``continue_model``,
      ``model_step``, ``data_extension``, ``dataset``
    * ``settings["Path"]``: ``save_path``, ``data_path``, ``pims_path``,
      ``pds_path``
    * ``settings["Monitor"]``: ``print_step``, ``save_step``
    * ``settings["GPU"]``: ``gpu_num``, ``gpu_enable``, ``cudnn_benchmark``

    Constructing an instance creates ``<save_path>/<branch_name>/models`` and
    ``<save_path>/<branch_name>/images`` (stored as ``self.model_save`` and
    ``self.image_save``), then parses the command line into the module-level
    global ``FLAGS`` using the settings values as defaults.
    """

    def __init__(self, settings):
        """Create (or reset) the save directories, then parse the flags.

        Args:
            settings: nested dict of configuration sections (see class
                docstring for the keys that are read).

        Note:
            When ``settings["Basic"]["continue_model"]`` is falsy, everything
            under both save directories is deleted before they are recreated.
            This decision uses the settings value, not the ``--continue_model``
            command-line flag.
        """
        # ===== Basic =====
        branch_name = settings["Basic"]["branch_name"]
        continue_model = settings["Basic"]["continue_model"]

        # ===== Path =====
        save_path = settings["Path"]["save_path"]

        # ===== Assemble and preprocess =====
        self.model_save = os.path.join(save_path, branch_name, "models")
        self.image_save = os.path.join(save_path, branch_name, "images")

        for directory in (self.model_save, self.image_save):
            directory_path = Path(directory)
            # Starting a fresh run: discard outputs left by a previous run of
            # this branch. Only remove what exists, then (re)create it once.
            if not continue_model and directory_path.exists():
                shutil.rmtree(directory_path)
            directory_path.mkdir(parents=True, exist_ok=True)

        self.relay_params(settings)

    @staticmethod
    def _str_to_bool(value):
        """Convert a command-line string such as "true"/"False"/"1"/"no" to bool.

        ``type=bool`` cannot be used with argparse because ``bool(s)`` is True
        for every non-empty string, including "False".

        Raises:
            argparse.ArgumentTypeError: if ``value`` is not a recognised
                boolean spelling (argparse turns this into a usage error).
        """
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ("true", "t", "yes", "y", "1", "on"):
            return True
        if lowered in ("false", "f", "no", "n", "0", "off"):
            return False
        raise argparse.ArgumentTypeError(
            f"expected a boolean value (true/false), got {value!r}"
        )

    def relay_params(self, settings):
        """Parse the command line, defaulting every flag to its settings value.

        Boolean flags take an explicit value, e.g. ``--gpu_enable false``.
        ``--model_save`` and ``--image_save`` default to the directories built
        in ``__init__``, so ``self.model_save``/``self.image_save`` must be set
        before this is called.

        Args:
            settings: nested dict of configuration sections (see class
                docstring).

        Returns:
            The parsed ``argparse.Namespace``, which is also stored in the
            module-level global ``FLAGS``.
        """
        parser = argparse.ArgumentParser()

        # ===== Basics =====
        parser.add_argument(
            '--continue_model',
            type=self._str_to_bool,
            default=settings["Basic"]["continue_model"],
            help='Load the model parameters and continue training.'
        )
        parser.add_argument(
            '--model_step',
            type=int,
            default=settings["Basic"]["model_step"],
            help='The model to load in if continue training.'
        )
        parser.add_argument(
            '--data_extension',
            type=str,
            default=settings["Basic"]["data_extension"],
            help='Extensions of the files.'
        )
        parser.add_argument(
            '--dataset',
            type=str,
            default=settings["Basic"]["dataset"],
            help='Name of the target dataset.'
        )

        # ===== Path =====
        parser.add_argument(
            '--model_save',
            type=str,
            default=self.model_save,
            help='Path to where models are saved.'
        )
        parser.add_argument(
            '--image_save',
            type=str,
            default=self.image_save,
            help='Path to where images are saved.'
        )
        parser.add_argument(
            '--data_path',
            type=str,
            default=settings["Path"]["data_path"],
            help='Path to the input images.'
        )
        parser.add_argument(
            '--pims_path',
            type=str,
            default=settings["Path"]["pims_path"],
            help='Path to the persistence images.'
        )
        parser.add_argument(
            '--pds_path',
            type=str,
            default=settings["Path"]["pds_path"],
            help='Path to the persistence diagrams.'
        )

        # ===== Monitor =====
        parser.add_argument(
            '--print_step',
            type=int,
            default=settings["Monitor"]["print_step"],
            help="Print training status every print_step steps."
        )
        parser.add_argument(
            '--save_step',
            type=int,
            default=settings["Monitor"]["save_step"],
            help="Save models and images every save_step steps."
        )

        # ===== GPU =====
        parser.add_argument(
            '--gpu_num',
            type=int,
            default=settings["GPU"]["gpu_num"],
            help='Number of GPU to run the network on.'
        )
        parser.add_argument(
            '--gpu_enable',
            type=self._str_to_bool,
            default=settings["GPU"]["gpu_enable"],
            help="If to use GPU for training."
        )
        parser.add_argument(
            '--cudnn_benchmark',
            type=self._str_to_bool,
            default=settings["GPU"]["cudnn_benchmark"],
            help="If set cudnn_benchmark true(good for fixed size inputs)."
        )

        global FLAGS
        # parse_known_args (not parse_args) ignores arguments this parser does
        # not define -- presumably so a notebook kernel's own argv does not
        # abort parsing. The leftovers are deliberately discarded.
        FLAGS, _unparsed = parser.parse_known_args()
        return FLAGS