In [1]:
import common_dl_utils as cdu
from common_dl_utils.config_creation import Config, VariableCollector
import json
import os
import numpy as np

In [3]:
# `config` describes one experiment; `variable` collects the entries that are swept over,
# so that every combination can be realized later on.
config = Config()
variable = VariableCollector()

# first we specify what the model should look like
config.architecture = './model_components'  # module containing all relevant classes for architectures
# NB if the classes relevant for creating the model are spread over multiple modules, this is no problem:
# let config.architecture be the module that contains the "main" model class, and for all other components just specify the module,
# or specify the other modules as default modules to the tools in common_jax_utils.run_utils
config.model_type = 'inr_modules.CombinedINR'

config.model_config = Config()
# NOTE(review): in_size is 2 while the data loader configured further down reads a .ply file — confirm the intended input dimensionality
config.model_config.in_size = 2
config.model_config.out_size = 1
config.model_config.terms = [  # CombinedINR uses multiple MLPs and returns the sum of their outputs. These 'terms' are the MLPs
    ('inr_modules.MLPINR.from_config',{
        'hidden_size': 256,
        'num_layers': 5,
        'layer_type': variable('inr_layers.SirenLayer', 'inr_layers.GaussianINRLayer', 'inr_layers.QuadraticLayer', group='method'),
        'num_splits': 3,
        'activation_kwargs': variable(
            {'w0':variable(10., 15., 20., 25., 30., 35., group="hyperparam")},  # you can nest variables and put complex data structures in their slots
            {'inverse_scale':variable(10., 15., 20., 25., 30., 35., group="hyperparam")},
            {'a': variable(10., 15., 20., 25., 30., 35., group="hyperparam")},  # all floats, consistent with the two sweeps above
            group='method'),  # by specifying a group, you make sure that the values in the group are linked, so SirenLayer will always go with w0, GaussianINRLayer with inverse_scale, and QuadraticLayer with a
    }),
]

# Training setup: the data source, the length of the training loop, the loss,
# and the 'target_function' that we want the model to mimic.
config.dataloader = 'sdf.SDFDataLoader'
config.dataloader_config = dict(
    path="example_data/xyzrgb_statuette.ply",
    batch_size=200_000,
    keep_aspect_ratio=True,
)

config.num_cycles = 50_000
config.steps_per_cycle = 200

config.loss_evaluator = "losses.SDFLossEvaluator"

# The target function uses the same class and the same settings as the data loader above.
# NOTE(review): the original note here read "see when config. losseval" — presumably it
# refers to how the loss evaluator consumes the target function; confirm.
config.target_function = 'sdf.SDFDataLoader'
config.target_function_config = dict(
    path="example_data/xyzrgb_statuette.ply",
    batch_size=200_000,
    keep_aspect_ratio=True,
)


# Optimizer settings.
# Only the final assignment is kept: the earlier value
# 'training.OptimizerFactory.single_optimizer' was overwritten immediately and never took effect.
# Other candidates mentioned in this cell: 'adamw', 'sgd'.
config.optimizer = 'adam'  # we'll have to add optax to the additional default modules later
config.optimizer_config = {
    'learning_rate': 1.e-4
}
# num_cycles and steps_per_cycle are not repeated here: they are already assigned
# earlier in this cell with the same values (50000 and 200).



config.components_module = "./inr_utils/"
config.post_processor_type = "post_processing.PostProcessor"
# one results directory per method; the "method" group links it to the layer type
config.storage_directory = variable(
    "factory_results/Siren_sdf",
    "factory_results/Gaussian_sdf",
    "factory_results/Quadratic_sdf",
    group="method",
)

config.after_cycle_callback = 'callbacks.ComposedCallback'
config.after_cycle_callback_config = dict(
    callbacks=[
        # print the loss; after_every is 1 here (presumably: after every cycle — confirm)
        ('callbacks.print_loss', dict(after_every=1)),
        # log the loss to wandb
        'callbacks.report_loss',
        # this thing will help us collect metrics and log images to wandb
        ('callbacks.MetricCollectingCallback', dict(metric_collector='metrics.MetricCollector')),
        # stop training if the loss becomes NaN
        'callbacks.raise_error_on_nan',
    ],
    show_logs=False,
)
# the wandb group is linked to the method as well
config.wandb_group = variable(
    "Siren_example_grad",
    "Gaussian_example_grad",
    "Quadratic_example_grad",
    group="method",
)
config.wandb_entity = "abdtab-tue"
config.wandb_project = "inr_edu_24"

In [4]:
# Inspect the collector's (private) mapping from group name to the number of values in that group.
# NOTE(review): the recorded output lists a 'datapoint' group, but no `variable(...)` call in the
# configuration cell uses that group, and the execution counts skip In [2] — presumably the
# output stems from stale kernel state; restart the kernel and re-run to confirm.
variable._group_to_lengths

{'method': 3, 'hyperparam': 6, 'datapoint': 2}

In [5]:
# Count the concrete configurations that the collector produces for `config`.
# NOTE(review): the recorded 36 equals 3 * 6 * 2, i.e. it presumably includes the 'datapoint'
# group that the configuration cell no longer declares — re-run on a fresh kernel to confirm.
sum(1 for _realization in variable.realizations(config))

36

In [6]:
# Directory that receives one serialized configuration file per realization.
target_dir = "./factory_configs/test"
# exist_ok=True keeps the cell re-runnable and avoids the check-then-create race
# of a separate os.path.exists test.
os.makedirs(target_dir, exist_ok=True)

# Paths of the configuration files written further down in this cell;
# the slurm-script generation iterates over this list.
config_files = []

class MyEncoder(json.JSONEncoder):
    """JSON encoder that serializes `Config` objects through their `data` attribute."""

    def default(self, o):
        """Return `o.data` for `Config` instances and defer to the base class otherwise."""
        if not isinstance(o, Config):
            # the base implementation raises TypeError for objects it cannot serialize
            return super().default(o)
        return o.data

# Write every realization of `config` to its own file and remember the path.
# NB the files carry a .yaml extension, but their content is produced by the json module.
for config_index, config_realization in enumerate(variable.realizations(config)):
    group_name = config_realization["wandb_group"]
    target_path = f"{target_dir}/{group_name}-{config_index}.yaml"
    with open(target_path, "w") as config_handle:
        # MyEncoder takes care of the (nested) Config objects
        config_handle.write(json.dumps(config_realization, cls=MyEncoder))
    config_files.append(target_path)
    

In [7]:
# Now create the slurm scripts that run the experiments.
# NB you will probably need to modify the account and the time limit in the template below,
# and possibly the output directory and the conda environment name.
slurm_directory = "./factory_slurm/test"
# exist_ok=True keeps the cell re-runnable without a separate existence check
os.makedirs(slurm_directory, exist_ok=True)

# Shared header of every generated job script; the command that runs a single
# configuration is appended to it per configuration file.
# NOTE(review): job output is directed to ./factory_output — presumably that directory has to
# exist where the jobs are submitted, since nothing in this notebook creates it; verify.
slurm_base = """#!/bin/bash
#SBATCH --account=tesr82932
#SBATCH --time=0:10:00
#SBATCH -p gpu
#SBATCH -N 1
#SBATCH --tasks-per-node 1
#SBATCH --gpus=1
#SBATCH --output=./factory_output/R-%x.%j.out
module load 2023
module load Miniconda3/23.5.2-0

# >>> conda initialize >>>
# !! Contents within this block are managed by 'conda init' !!
__conda_setup="$('/sw/arch/RHEL8/EB_production/2023/software/Miniconda3/23.5.2-0/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
if [ $? -eq 0 ]; then
    eval "$__conda_setup"
else
    if [ -f "/sw/arch/RHEL8/EB_production/2023/software/Miniconda3/23.5.2-0/etc/profile.d/conda.sh" ]; then
        . "/sw/arch/RHEL8/EB_production/2023/software/Miniconda3/23.5.2-0/etc/profile.d/conda.sh"
    else
        export PATH="/sw/arch/RHEL8/EB_production/2023/software/Miniconda3/23.5.2-0/bin:$PATH"
    fi
fi
unset __conda_setup
# <<< conda initialize <<<

conda init bash
conda activate snel_bep  # conda environment name

wandblogin="$(< ./wandb.login)"  # password stored in a file, don't add this file to your git repo!
wandb login "$wandblogin"


echo 'Starting new experiment!';
"""

# Write one slurm script per configuration file: the shared template followed by
# the command that runs that single configuration.
for config_file in config_files:
    slurm_script = slurm_base + f"\npython run_single.py --config={config_file}"
    # Drop the directory and only the final extension, so that a dot elsewhere in the
    # file name (e.g. inside a wandb group name) cannot truncate the script name and
    # make two configurations map onto the same script file.
    config_stem = os.path.splitext(os.path.basename(config_file))[0]
    slurm_file_name = f"{config_stem}.bash"
    with open(f"{slurm_directory}/{slurm_file_name}", "w") as slurm_file:
        slurm_file.write(slurm_script)
