In [1]:
import os
import subprocess
from datetime import datetime
import shutil
from copy import deepcopy
from itertools import product
import torch
import random

In [8]:
def get_subdirectories(base_dir):
    """Return the paths of the directories located directly inside ``base_dir``.

    Every returned path is ``base_dir`` joined with the entry name. Entries
    that are not directories are left out. The order is whatever
    ``os.listdir`` yields, so callers that need a stable order must sort.
    """
    subdirs = []
    for entry in os.listdir(base_dir):
        candidate = os.path.join(base_dir, entry)
        if os.path.isdir(candidate):
            subdirs.append(candidate)
    return subdirs
    
def get_bash_cmd(script_path, slurm_args: list[tuple[str, str]], args: list[str]):
    """Build the ``sbatch`` argument list used to submit ``script_path``.

    Args:
        script_path: Path of the script handed to ``sbatch``.
        slurm_args: Iterable of ``(option, value)`` pairs. Each pair becomes
            one ``--option=value`` argument placed before the script path.
        args: Positional arguments placed after the script path. Any
            iterable of strings works (list, tuple, generator, ...).

    Returns:
        A new list of the form ``["sbatch", *options, script_path, *args]``.
    """
    slurm_options = [f"--{option}={value}" for option, value in slurm_args]
    # Unpacking (rather than `list + args`) keeps this working when `args`
    # is a tuple or another non-list iterable.
    return ["sbatch", *slurm_options, script_path, *args]

def run_cmd(cmd_list):
    """Execute ``cmd_list`` through ``subprocess.run`` with ``check=True``.

    Nothing is returned. Because of ``check=True``, ``subprocess.run`` raises
    ``subprocess.CalledProcessError`` when the command exits non-zero.
    """
    subprocess.run(cmd_list, check=True)

def run_bash_script(script_path, slurm_args: list[tuple[str, str]], args: list[str], do_run=True):
    """Print the ``sbatch`` command for ``script_path`` and optionally run it.

    Args:
        script_path: Script to submit.
        slurm_args: ``(option, value)`` pairs forwarded to ``get_bash_cmd``.
        args: Positional arguments forwarded to ``get_bash_cmd``.
        do_run: When falsy the command is only printed (dry run); otherwise
            it is also executed through ``run_cmd``.
    """
    command = get_bash_cmd(script_path, slurm_args, args)
    print(command)
    if not do_run:
        return
    run_cmd(command)

def zip_dicts(global_hp_args):
    """Pair up the value lists index-wise, yielding one dict per index.

    ``global_hp_args`` maps each name to a sequence of values. The i-th
    returned dict maps every name to the i-th value of its sequence. As with
    ``zip``, the result is as long as the shortest sequence; an empty mapping
    gives an empty list.
    """
    names = list(global_hp_args)
    columns = list(global_hp_args.values())
    rows = []
    for row in zip(*columns):  # transpose: one row per index
        rows.append({name: value for name, value in zip(names, row)})
    return rows

def iterate_in_tuples(lst, n=2):
    """Split ``lst`` into consecutive tuples of ``n`` elements.

    ``lst`` is both iterated and sliced. Elements left over after the last
    full group are returned as one final, shorter tuple.
    """
    source = iter(lst)
    # zip over n references to the same iterator pulls n consecutive
    # elements per output tuple and stops at the first incomplete group.
    shared_refs = [source for _ in range(n)]
    chunks = [chunk for chunk in zip(*shared_refs)]
    consumed = n * len(chunks)
    leftover = lst[consumed:]
    if leftover:
        chunks.append(tuple(leftover))
    return chunks

def generate_hp_permutations(global_hp_args):
    """Return every combination of hyperparameter values (Cartesian product).

    ``global_hp_args`` maps each name to an iterable of candidate values.
    One dict is produced per combination, in the order ``itertools.product``
    generates them. A falsy argument (e.g. an empty dict) yields ``[{}]``:
    a single run with no overrides.
    """
    if not global_hp_args:
        return [{}]
    names = list(global_hp_args)
    value_lists = [global_hp_args[name] for name in names]
    combos = []
    for combo in product(*value_lists):
        combos.append(dict(zip(names, combo)))
    return combos
    
# Timestamp taken when this cell runs; it prefixes the log and model folder
# names below so every launch gets its own directories.
datetime_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Optional suffix for the model output folder; set to "" for no suffix.
CUSTOM_NAME = "ablat_novxx"
if CUSTOM_NAME:
    print("Are you SURE you want a custom name", CUSTOM_NAME)

config_folder = "/scratch/rhm4nj/cral/cral-ginn/ginn/configs"
config_name = "config_3dis_adapter.yml"
data_directory = "/scratch/rhm4nj/cral/cral-ginn/ginn/myvis/data_gen/replica/room_0_objects"
script_path = "/scratch/rhm4nj/cral/cral-ginn/slurm_scripts/ginn_script.sh"

# "<timestamp>_<dataset folder name>". normpath drops a trailing "/" that
# would otherwise make the folder name come out empty.
_run_name = datetime_str + "_" + os.path.basename(os.path.normpath(data_directory))

# SLURM log locations for this launch.
slurm_path = os.path.join("/scratch/rhm4nj/cral/cral-ginn/slurm_scripts/logs", _run_name)
slurm_out = os.path.join(slurm_path, "out")
slurm_err = os.path.join(slurm_path, "err")

# Only append the custom suffix when one is set, so an empty CUSTOM_NAME does
# not leave a dangling "_" at the end of the folder name.
_model_run_name = _run_name + "_" + CUSTOM_NAME if CUSTOM_NAME else _run_name
model_out = os.path.join(
    "/scratch/rhm4nj/cral/cral-ginn/ginn/all_runs/models/experiments", _model_run_name
)

Are you SURE you want a custom name ablat_novxx


In [9]:
# --- Launch settings ---------------------------------------------------------
# group_size = torch.cuda.device_count()
group_size = 1  # runs bundled into one sbatch call (one array task per run)
# gpu_types = ['a100', 'v100', 'h200', 'a6000', 'a40']
my_gpu = ""  # optional GPU type, sent as --gres=gpu:<type>; "" sends --gres=gpu
n_groups = 100  # upper bound on the number of sbatch calls
class_start_idx = 0  # slice bounds into the sorted object folders
class_end_idx = 100
# Keys whose name is left out of the run tag (only their value is kept).
skip_names = ["activation_name", "obj_name", "adapter_mid_layers", "min_outer_val", "max_domain_val", 'layers']
# Keys that contribute to the run tag; an empty list means "every key".
tag_keys = ["lambda_vxx", "lambda_vx", "lambda_bc", "lambda_scc", "lambda_descent", "lambda_outer_env", "obj_name", "lambda_dom", "max_domain_val", "env_val_tol", "layers"]
do_perm = False  # True: Cartesian product of the value lists; False: pair them index-wise
abbrv = False  # True: shorten key names in tags to the initials of their "_" parts
do_run = True  # False: dry run (no folders created, commands printed but not submitted)

# --- Select the object folders -----------------------------------------------
# Folder names start with an integer ("<idx>_<class>"); sort numerically on it.
all_obj_folders = sorted(
    get_subdirectories(data_directory),
    key=lambda folder: int(os.path.basename(folder).split('_')[0]),
)
all_obj_names = [os.path.basename(folder) for folder in all_obj_folders]
print(all_obj_names)
obj_folders = all_obj_folders[class_start_idx:class_end_idx]
obj_names = all_obj_names[class_start_idx:class_end_idx]

print(len(obj_names), obj_names)
print("Classes", obj_names)

# --- Hyperparameter grid -----------------------------------------------------
# Each key maps to a list of values. With do_perm=False the lists are paired
# index-wise (zip_dicts stops at the shortest list), so they should all have
# the same length.
global_hp_args = {
    "dataset_dir": obj_folders,
    "obj_name": obj_names,
    # "env_val_tol": [0.1, 0.2, 0.5],
    # "lambda_dom": [1],
    # "max_domain_val": [-1, -2, -5],
    # "lambda_descent": [1, 1.0e-1],
    # "lambda_outer_env": [2, 5, 10],
    # "lambda_descent": [1],
    # "lambda_small_control": [1e-3],
    # "lambda_descent": [0] * len(obj_names),
    "lambda_vxx": [0] * len(obj_names),
    # "lambda_vx": [0] * len(obj_names),
    # "cbf_lambda": [1.0e-5, -1.0e-5],
    # "min_control_norm": [0.5]
    # "min_outer_val": [1.0e-4, 1.0e-3, 1.0e-2, 0],
    # "lambda_bc": [10],
    # "lambda_scc": [1, 0.5, 0.1, 0.05],
    # "lambda_vxx": [1.0e-2, 1.0e-3, 1.0e-4],
    # "lambda_vx": [1.0e-3, 1.0e-4, 1.0e-5],
    # "controller_step_range": [[4, 5]],
    # "layers": [[4, 256, 256, 256, 256, 256, 1], [4, 256, 256, 256, 256, 1], [4, 256, 256, 256, 1], [4, 256, 256, 1]],
}

# --- Prepare output folders --------------------------------------------------
if do_run:
    for folder in (slurm_out, slurm_err, model_out):
        os.makedirs(folder, exist_ok=True)

    # Keep a copy of the config next to the logs of this launch.
    shutil.copy2(os.path.join(config_folder, config_name), slurm_path)

# Render floats as fixed-point strings (10 decimals, trailing zeros removed)
# so tags and argument strings never contain scientific notation.
for key, values in global_hp_args.items():
    global_hp_args[key] = [
        format(value, ".10f").rstrip('0').rstrip('.') if isinstance(value, float) else value
        for value in values
    ]

# --- Expand the grid into individual runs ------------------------------------
if do_perm:
    hp_permutations = generate_hp_permutations(global_hp_args)
else:
    hp_permutations = zip_dicts(global_hp_args)

if abbrv:
    tag_keys_abrv = {key: ''.join([part[0] for part in key.split("_")]) for key in global_hp_args.keys()}
else:
    tag_keys_abrv = {key: key for key in global_hp_args.keys()}

if len(hp_permutations) > 100:
    # Safety stop: nothing is submitted, only this warning is printed.
    print("ARE YOU SURE about", len(hp_permutations), "runs?")

else:
    if not tag_keys:
        tag_keys = list(global_hp_args)

    hp_strs = []
    tags = []
    for hp_args in hp_permutations:
        if global_hp_args:
            # Tag = "<key>_<value>" for every tagged key, joined by "_". Keys
            # in skip_names contribute their value only. The prefix is omitted
            # per key instead of being stripped with str.replace afterwards,
            # which could also hit unrelated text and raised KeyError with
            # abbrv=True for skip names that are not hyperparameters.
            tag_parts = []
            for key, val in hp_args.items():
                if key not in tag_keys:
                    continue
                if key in skip_names:
                    tag_parts.append(str(val))
                else:
                    tag_parts.append(tag_keys_abrv[key] + "_" + str(val))
            tag = '_'.join(tag_parts)
        else:
            tag = datetime_str

        hp_args["model_save_path"] = os.path.join(model_out, tag)
        hp_args["wandb_experiment_name"] = tag
        # Overrides are handed to the script as one "key:value;key:value" string.
        hp_str = ';'.join([key + ":" + str(val) for key, val in hp_args.items()])

        hp_strs.append(hp_str)
        tags.append(tag)

    hp_strs_groups = iterate_in_tuples(hp_strs, n=group_size)[:n_groups]
    tag_groups = iterate_in_tuples(tags, n=group_size)[:n_groups]

    # --- Submit one sbatch call per group ------------------------------------
    for tag_group, hp_strs_group in zip(tag_groups, hp_strs_groups):
        job_name = "_".join(tag_group)
        # "%a" is left for sbatch to expand (job array index).
        my_slurm_err = os.path.join(slurm_err, job_name + "_error_%a.err")
        my_slurm_out = os.path.join(slurm_out, job_name + "_output_%a.out")
        # One array task per run in the group: "0" alone, or "0-<n-1>".
        array_str = "0"
        if len(hp_strs_group) > 1:
            array_str += "-" + str(len(hp_strs_group) - 1)

        print(job_name)
        gpu_str = "gpu:" + my_gpu if my_gpu else "gpu"
        # To restrict GPU types instead, add e.g.
        # ("constraint", "a100|h200|v100") to slurm_args below.
        run_bash_script(
            script_path,
            slurm_args=[
                ("output", my_slurm_out),
                ("error", my_slurm_err),
                ("job-name", job_name),
                ("array", array_str),
                ("gres", gpu_str),
            ],
            args=[config_name] + list(hp_strs_group),
            do_run=do_run,
        )

    print("Individual:",  len(hp_strs))
    print("Total Calls:", len(hp_strs_groups))


['0_wall', '1_lamp', '2_table', '3_table', '4_table', '5_stool', '6_chair', '7_sofa', '8_stool', '9_chair', '10_table', '11_lamp', '12_ceiling', '13_floor']
14 ['0_wall', '1_lamp', '2_table', '3_table', '4_table', '5_stool', '6_chair', '7_sofa', '8_stool', '9_chair', '10_table', '11_lamp', '12_ceiling', '13_floor']
Classes ['0_wall', '1_lamp', '2_table', '3_table', '4_table', '5_stool', '6_chair', '7_sofa', '8_stool', '9_chair', '10_table', '11_lamp', '12_ceiling', '13_floor']
0_wall_lambda_vxx_0
['sbatch', '--output=/scratch/rhm4nj/cral/cral-ginn/slurm_scripts/logs/2025-05-01_00-21-14_room_0_objects/out/0_wall_lambda_vxx_0_output_%a.out', '--error=/scratch/rhm4nj/cral/cral-ginn/slurm_scripts/logs/2025-05-01_00-21-14_room_0_objects/err/0_wall_lambda_vxx_0_error_%a.err', '--job-name=0_wall_lambda_vxx_0', '--array=0', '--gres=gpu', '/scratch/rhm4nj/cral/cral-ginn/slurm_scripts/ginn_script.sh', 'config_3dis_adapter.yml', 'dataset_dir:/scratch/rhm4nj/cral/cral-ginn/ginn/myvis/data_gen/repl

In [4]:
# def iterate_in_pairs(lst):
#     it = iter(lst)
#     return list(zip(it, it)) + ([tuple([lst[-1]])] if len(lst) % 2 else [])

# # if not os.path.exists(slurm_out):
# #     os.makedirs(slurm_out)
# # if not os.path.exists(slurm_err):
# #     os.makedirs(slurm_err)
# # if not os.path.exists(model_out):
# #     os.makedirs(model_out)
# # shutil.copy2(os.path.join(config_folder, config_name), slurm_path)

# n = 20
# subdirs = get_subdirectories(data_directory)[:n]

# global_hp_args = {
#     'lambda_bound': 1e-4,
# }

# for my_inpdir_pairs in iterate_in_pairs(subdirs):
#     my_names = []
#     args = [config_name]
#     hp_args = deepcopy(global_hp_args)

#     for my_inpdir in my_inpdir_pairs:
#         my_name =  my_inpdir.split("/")[-1]
#         my_model_out = os.path.join(model_out,my_name)

#         args.append(my_inpdir)
#         args.append(my_model_out)
#         my_names.append(my_name)

#         hp_args["dataset_dir"] = my_inpdir
#         hp_args["model_save_path"] = my_model_out

#     hp_str = ';'.join([key + ":" + str(val) for key, val in hp_args.items()])
#     print(hp_str)
    
#     my_name = '_'.join(my_names)
#     my_slurm_out = os.path.join(slurm_out, f"{my_name}_%a.out")
#     my_slurm_err = os.path.join(slurm_err, f"{my_name}_%a.err")

#     array_arg = "0"
#     if len(my_inpdir_pairs) > 1:
#         array_arg = "0-1"

#     run_bash_script(script_path, 
#         slurm_args=[("output", my_slurm_out), ("error", my_slurm_err), ("job-name", my_name + "_" + datetime_str), ("array", array_arg)], 
#         args=args
#     )