In [21]:
import os
import subprocess
from datetime import datetime
import shutil
from copy import deepcopy
from itertools import product


In [22]:
def get_bash_cmd(script_path, slurm_args: tuple[str, str], args: list[str]):
    slurm_list = ["--" + a + "="  + b for a, b in slurm_args]
    cmd_list = ["sbatch"] + slurm_list + [script_path] + args
    return cmd_list

def run_cmd(cmd_list): subprocess.run(cmd_list, check=True)

def run_bash_script(script_path, slurm_args: tuple[str, str], args: list[str], do_run=True):
    cmd = get_bash_cmd(script_path, slurm_args, args)
    print(cmd)
    if do_run: run_cmd(cmd)

def zip_dicts(global_hp_args):
    keys = global_hp_args.keys()
    values = zip(*global_hp_args.values())  # Transpose the values
    return [dict(zip(keys, v)) for v in values]

def iterate_in_tuples(lst, n=2):
    it = iter(lst)
    tuples = list(zip(*[it] * n))  # Groups elements into n-sized tuples
    remainder = lst[len(tuples) * n:]  # Get remaining elements, if any
    if remainder:
        tuples.append(tuple(remainder))  # Add remainder as a smaller tuple
    return tuples

def generate_hp_permutations(global_hp_args):
    if not global_hp_args:
        return [{}]
    keys, values = zip(*global_hp_args.items())  # Extract keys and value lists
    permutations = [dict(zip(keys, v)) for v in product(*values)]
    return permutations
    
exp_name = "layer_sizes"
datetime_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

config_folder = "/scratch/rhm4nj/cral/cral-ginn/ginn/configs"
config_name = "adapter_config.yml"
adapter_config_path = os.path.join(config_folder, config_name)

pretrained_directory = "/scratch/rhm4nj/cral/cral-ginn/ginn/all_runs/models/experiments/2025-02-24_11-27-03_Area_1"
base_directory = "/scratch/rhm4nj/cral/cral-ginn/ginn/myvis/data_gen/S3D/Area_1"
script_path = "/scratch/rhm4nj/cral/cral-ginn/slurm_scripts/ginn_script_adapter.sh"

slurm_path = os.path.join(
    "/scratch/rhm4nj/cral/cral-ginn/slurm_scripts/adapter_logs",
    exp_name + "_" + datetime_str + "_" + base_directory.split("/")[-1]
)
slurm_out = os.path.join(slurm_path, "out")
slurm_err = os.path.join(slurm_path, "err")

model_out = os.path.join(
    "/scratch/rhm4nj/cral/cral-ginn/ginn/all_runs/models/experiments/adapter",  exp_name + "_" + datetime_str + "_" + base_directory.split("/")[-1]
)

tensorboard_dir = os.path.join(
    "/scratch/rhm4nj/cral/cral-ginn/ginn/all_runs/adapter",  exp_name + "_" + datetime_str + "_" + base_directory.split("/")[-1]
)

print("Models saved to:", model_out)
print("Check graphs at\ntensorboard --logdir", tensorboard_dir)

Models saved to: /scratch/rhm4nj/cral/cral-ginn/ginn/all_runs/models/experiments/adapter/layer_sizes_2025-02-28_22-36-50_Area_1
Check graphs at
tensorboard --logdir /scratch/rhm4nj/cral/cral-ginn/ginn/all_runs/adapter/layer_sizes_2025-02-28_22-36-50_Area_1


In [23]:
config_name = "config.yml"
target_epoch = 200

pretrained_models = []
config_paths = []
obj_names = []

for dirname in sorted(os.listdir(pretrained_directory)):
    dir_path = os.path.join(pretrained_directory, dirname)
    config_path = ""
    model_path = ""
    model_max_path = ""
    max_epoch = 0

    for root, _, files in os.walk(dir_path):
        for file in files:
            if file == config_name:
                config_path = os.path.join(root, config_name)
            
            if ".pth" not in file: continue 
            try:
                epoch = int(file.split("_")[-1].split(".")[0])
            except:
                print("Invalid file name:", file)
                continue
            
            if epoch == target_epoch:
                model_path = os.path.join(root, file)
                break

            if epoch > max_epoch:
                model_max_path = os.path.join(root, file)
                max_epoch = epoch

    if not model_path:
        if max_epoch:
            print(dirname, "Using max epoch", max_epoch)
            model_path = model_max_path
    
    if not (model_path and config_path):
        print("Nothing found for", dirname)
        continue

    pretrained_models.append(model_path)
    config_paths.append(config_path)
    obj_names.append(dirname)


Nothing found for 20_table
4_wall Using max epoch 160
7_beam Using max epoch 100
8_beam Using max epoch 170
9_door Using max epoch 160


In [24]:
group_size = 3
n_groups = 100
class_start_idx = 13
class_end_idx = 14
skip_names = ["activation_name", "obj_name", "adapter_mid_layers"]
tag_keys = ["lambda_lie_norm", "lambda_control", "lambda_descent", "lambda_recon", "activation_name", "adapter_mid_layers"]
do_perm = True
abbrv = False
do_run = True

if do_run:
    if not os.path.exists(slurm_out):
        os.makedirs(slurm_out)
    if not os.path.exists(slurm_err):
        os.makedirs(slurm_err)
    if not os.path.exists(model_out):
        os.makedirs(model_out)


global_hp_args = {
    "siren_config_path": config_paths[class_start_idx:class_end_idx],
    "pretrained_siren_path": pretrained_models[class_start_idx:class_end_idx],
    "obj_name": obj_names[class_start_idx:class_end_idx],
    # "lambda_recon": [0.25, 0.5, 1],
    # "lambda_descent": [0.5, 1],
    # "lambda_control": [1.0e-1],
    # "lambda_lie_norm": [1.0e-3, 1.0e-2],
    # "activation_name": ['tanh', 'gelu', 'relu', 'leaky_relu', 'softplus']
    # "activation_name": ['leaky_relu', 'gelu']
    "adapter_mid_layers": [[512, 256, 256, 1], [512, 256, 128, 1], [256, 128, 64, 1], [512, 128, 1], [256, 128, 1], 
        [256, 256, 1], [128, 64, 1], [256, 1]
    ]
}

if do_perm:
    # do permutations
    hp_permutations = generate_hp_permutations(global_hp_args)
    for i, perm in enumerate(hp_permutations):
        for key in perm:
            if isinstance(perm[key], float):
                hp_permutations[i][key] = format(perm[key], ".10f").rstrip('0').rstrip('.')  # Adjust decimal places as needed
else:
    # do tuples
    hp_permutations = zip_dicts(global_hp_args)

if abbrv:
    tag_keys_abrv = {key: ''.join([k[0] for k in key.split("_")]) for key in global_hp_args.keys()}
else:
    tag_keys_abrv = {key: key for key in global_hp_args.keys()}

if len(hp_permutations) > 100:
    print("ARE YOU SURE about", len(hp_permutations), "runs?")
else:
    hp_strs = []
    tags = []
    for k, hp_args in enumerate(hp_permutations):
        if len(global_hp_args) >= 1: 
            for key, val in hp_args.items():
                if not tag_keys: 
                    tag_keys = [k for k in global_hp_args]
                tag = '_'.join([tag_keys_abrv[key] + "_" + str(val) for key, val in hp_args.items() if key in tag_keys])
                for name in skip_names:
                    if abbrv:
                        tag = tag.replace(tag_keys_abrv[name] + "_", "")
                    else:
                        tag = tag.replace(name + "_", "")

        else: 
            tag = datetime_str
        
        hp_args["model_save_path"] = os.path.join(model_out, tag)
        hp_args["tensorboard_log_dir"] = os.path.join(tensorboard_dir, tag)
        hp_str = ';'.join([key + ":" + str(val) for key, val in hp_args.items()])

        hp_strs.append(hp_str)
        tags.append(tag)
            
    hp_strs_groups = iterate_in_tuples(hp_strs, n=group_size)[:n_groups]
    tag_groups = iterate_in_tuples(tags, n=group_size)[:n_groups]

    for tag_group, hp_strs_group in zip(tag_groups, hp_strs_groups):
        job_name = "_".join(tag_group)
        my_slurm_err = os.path.join(slurm_err, job_name + f"_error_%a.err")
        my_slurm_out = os.path.join(slurm_out, job_name + f"_output_%a.out")
        array_str = "0"
        if len(tag_group) > 1: 
            array_str += "-" + str(len(hp_strs_group) - 1)
        
        run_bash_script(script_path, 
            slurm_args=[("output", my_slurm_out), ("error", my_slurm_err), ("job-name", job_name), ("array", array_str)], 
            args=[adapter_config_path] + list(hp_strs_group),
            do_run=do_run
        )

    print("Individual:", len(hp_strs))
    print("Total Calls:", len(hp_strs_groups))

print("Classes", obj_names[class_start_idx:class_end_idx])

['sbatch', '--output=/scratch/rhm4nj/cral/cral-ginn/slurm_scripts/adapter_logs/layer_sizes_2025-02-28_22-36-50_Area_1/out/[512, 256, 256, 1]_[512, 256, 128, 1]_[256, 128, 64, 1]_output_%a.out', '--error=/scratch/rhm4nj/cral/cral-ginn/slurm_scripts/adapter_logs/layer_sizes_2025-02-28_22-36-50_Area_1/err/[512, 256, 256, 1]_[512, 256, 128, 1]_[256, 128, 64, 1]_error_%a.err', '--job-name=[512, 256, 256, 1]_[512, 256, 128, 1]_[256, 128, 64, 1]', '--array=0-2', '/scratch/rhm4nj/cral/cral-ginn/slurm_scripts/ginn_script_adapter.sh', '/scratch/rhm4nj/cral/cral-ginn/ginn/configs/adapter_config.yml', 'siren_config_path:/scratch/rhm4nj/cral/cral-ginn/ginn/all_runs/models/experiments/2025-02-24_11-27-03_Area_1/22_table/config.yml;pretrained_siren_path:/scratch/rhm4nj/cral/cral-ginn/ginn/all_runs/models/experiments/2025-02-24_11-27-03_Area_1/22_table/cond_siren/2025_02_24-11_55_43/2025_02_24-11_55_45-zj2kjfib_200.pth;obj_name:22_table;adapter_mid_layers:[512, 256, 256, 1];model_save_path:/scratch/rh