In [4]:
import os
import yaml
import itertools
import datetime
import shutil
# import hydra
from omegaconf import DictConfig
import fileinput


def load_test_config(n_lst, p_lst, m_lst, noise_levels, model_lst, cv_num, trial_num):
    config_list = []
    
    if len(n_lst) == 1:
        n_lst = n_lst[0]
    if len(p_lst) == 1:
        p_lst = p_lst[0]
    if len(m_lst) == 1:
        m_lst = m_lst[0]

    # Generate all combinations of n, p, and m
    for model in model_lst:
        config = {
            "n": n_lst,
            "p": p_lst,
            "m": m_lst,
            "noise_levels": noise_levels,
            "model": model,
            "cv_num": cv_num,
            "trial_num": trial_num,

        }
        config_list.append(config)
    return config_list


def load_model_config(
    signal_bag_percent,
    atom_bag_percent,
    select_atom_percent,
    replace_flag,
    agg_func,
    ignore_warning,
    random_seed,
    Bag_lst,
    K_lst,
):
    config_list = []

    # Generate all combinations of replace_flag
    combinations = list(itertools.product(replace_flag))

    for combination in combinations:
        replace_flag_val = combination[0]
        signal_bag_percent_val = [
            val for val in signal_bag_percent if val <= 1 or replace_flag_val
        ]
        config = {
            "signal_bag_percent": signal_bag_percent_val,
            "atom_bag_percent": atom_bag_percent,
            "select_atom_percent": select_atom_percent,
            "replace_flag": replace_flag_val,
            "agg_func": agg_func,
            "ignore_warning": ignore_warning,
            "random_seed": random_seed,
            "Bag_lst": Bag_lst,
            "K_lst": K_lst,
        }
        config_list.append(config)

    return config_list

def create_yaml_and_shell(
    test_config, model_config, configs_path,shell_script_path, hydra, core_num, walltime, mail_address
):
    os.makedirs(configs_path, exist_ok=True)
    os.makedirs(shell_script_path, exist_ok=True)
    merged_config = itertools.product(test_config, model_config)
    formatted_date = datetime.date.today().strftime("%m%d")
    shell_files = []
    for config in merged_config:
        temp_test_config, temp_model_config = config
        temp_config = {"TEST": temp_test_config, "MODEL": temp_model_config}
        if hydra:
            temp_config["hydra"] = {
                "hydra_logging": {"level":"CRITICAL"},
                "job_logging": {"level":"CRITICAL"},
                "run": {
                    "dir": "memory/" + formatted_date + "/",
                },
            }
        else:
            temp_config["output_path"] = "memory/" + formatted_date + "/"
        with_replace_flag = temp_model_config["replace_flag"]
        filename = f"{temp_test_config['model']}_{temp_test_config['n']}_{temp_test_config['p']}_{temp_test_config['m']}_{['nr','r'][with_replace_flag]}_{formatted_date}.yaml"
        temp_config["filename"] = filename
        # output_path = [no_replacement_path, with_replacement_path][
        #     temp_model_config["replace_flag"]
        # ]
        output_path = configs_path
        with open(os.path.join(output_path, filename), "w") as file:
            yaml.dump(temp_config, file)

        script_base = f"""#!/bin/sh
#SBATCH --account=stats
#SBATCH --job-name={temp_test_config['model']}_{temp_test_config['n']}_{temp_test_config['p']}_{temp_test_config['m']}_{formatted_date}
#SBATCH -c {core_num}
#SBATCH -t {walltime}
#SBATCH -C mem192
#SBATCH --mail-type=ALL 
#SBATCH --mail-user={mail_address}

module load anaconda
#Command to execute Python program
python BOMP_testing.py --config-name {filename} --config-path {configs_path}
#End of script
"""
        shell_filename = f"{temp_test_config['model']}_{temp_test_config['n']}_{temp_test_config['p']}_{temp_test_config['m']}_{['nr','r'][temp_model_config['replace_flag']]}_{formatted_date}.sh"
        shell_files.append(shell_filename)


        with open(os.path.join(shell_script_path, shell_filename), "w") as file:
            file.write(script_base)
        if with_replace_flag:
            OMP_filename = f"Baseline_{temp_test_config['n']}_{temp_test_config['p']}_{temp_test_config['m']}_{formatted_date}.sh"
            shutil.copyfile(
                os.path.join(shell_script_path, shell_filename),
                os.path.join(shell_script_path, OMP_filename),
            )
            with fileinput.FileInput(os.path.join(shell_script_path, OMP_filename), inplace=True) as file:
                for line in file:
                    print(line.replace("BOMP_testing.py", "OMP_testing.py"), end="")
            shell_files.append(OMP_filename)

    print("Don't forget to upload the configs to the server!")

    print("Don't forget to upload the shell scripts to the server!")

    print("Don't forget to submit the jobs to the server!")

    print(f"Make sure you have the right email address: {mail_address}")

    print("Just copy and paste the following commands to the terminal:")
    for shell_file in shell_files:
        print(f"sbatch {shell_file}")

In [6]:
test_config = load_test_config(
    n_lst=[600],
    p_lst=[1000],
    m_lst=[10,20],
    noise_levels=[0.02, 0.04, 0.06, 0.08, 0.1], #David
    # noise_levels=[0.12, 0.14, 0.16, 0.18, 0.2], #sty
    # noise_levels=[0.22, 0.24, 0.26, 0.28, 0.3], #zhy
    model_lst=["BOMP"],
    cv_num=5,
    trial_num=10,
)
model_config = load_model_config(
    signal_bag_percent=[0.7,0.9,1.1,1.3,1.5],
    atom_bag_percent=[0.5,0.7,0.9],
    select_atom_percent=0,
    replace_flag=[True, False],
    agg_func="weight",
    ignore_warning=True,
    random_seed=1,
    Bag_lst=[1, 50, 100, 200, 300],
    K_lst=list(range(50,401,50)),
)

# merge the two configs

configs_path = "./configs"
shell_script_path = "./shell_scripts"
hydra = False
core_num = 16
walltime = "5-00:00"
mail_address = "sz3091@columbia.edu"

Remove_configs = True
if Remove_configs:
    shutil.rmtree(configs_path)

Remove_shell_scripts = True
if Remove_shell_scripts:
    shutil.rmtree("shell_scripts")


create_yaml_and_shell(
    test_config, model_config, configs_path,shell_script_path, hydra, core_num, walltime, mail_address
)

Don't forget to upload the configs to the server!
Don't forget to upload the shell scripts to the server!
Don't forget to submit the jobs to the server!
Make sure you have the right email address: sz3091@columbia.edu
Just copy and paste the following commands to the terminal:
sbatch BOMP_600_1000_[10, 20]_r_0718.sh
sbatch Baseline_600_1000_[10, 20]_0718.sh
sbatch BOMP_600_1000_[10, 20]_nr_0718.sh
