In [1]:
import common_dl_utils as cdu
from common_dl_utils.config_creation import Config, VariableCollector
import json
import os

In [2]:
# `config` holds the full experiment description that is filled in below.
# `variable` is used below to mark settings that have several candidate values;
# `variable.realizations(config)` later expands those into concrete configs.
config = Config()
variable = VariableCollector()

# First we specify what the model should look like.
config.architecture = './model_components'  # module containing all relevant classes for architectures
# NB: it is no problem if the classes needed to create the model are spread over multiple modules.
# Let config.architecture be the module that contains the "main" model class, and for all other components
# either specify the module explicitly or pass the other modules as default modules to the tools in
# common_jax_utils.run_utils.
# NOTE(review): this notebook imports common_dl_utils, not common_jax_utils - confirm which package is meant.
config.model_type = 'inr_modules.CombinedINR'

config.model_config = Config()
config.model_config.in_size = 2
config.model_config.out_size = 1
config.model_config.terms = [  # CombinedINR uses multiple MLPs and returns the sum of their outputs. These 'terms' are the MLPs.
    ('inr_modules.MLPINR.from_config',{
        'hidden_size': 256,
        'num_layers': 5,
        'layer_type': variable('inr_layers.SirenLayer', 'inr_layers.GaussianINRLayer', group='method'),
        'num_splits': 3,
        'activation_kwargs': variable(
            {'w0':variable(12., 18., 24., 30., 36.)},  # variables can be nested, and their slots can hold complex data structures
            {'inverse_scale':variable(1., 2.5, 5., 7.5)}, 
            group='method'),  # values in the same group are linked: SirenLayer will always go with w0 and GaussianINRLayer with inverse_scale
    }),
]

# Next, we set up the training loop, including the 'target_function' that we want to mimic.
config.trainer_module = './inr_utils/'  
config.trainer_type = 'training.train_inr'  # NB: a different training loop can be used instead, e.g. training.train_inr_scan, to make it train much faster
config.loss_evaluator = 'losses.PointWiseGradLossEvaluator'
config.target_function = 'images.ContinuousImage'
config.target_function_config = {
    # Two candidate images; the "datapoint" group is also used for the sampler size/batch_size below,
    # so those settings are meant to switch together with the image.
    'image': variable('./example_data/gray_parrot_grads_scaled.npy', "example_data/gray_flower_grads_scaled.npy", group="datapoint"),  # previously the single value './example_data/gray_parrot_grads_scaled.npy'
    'scale_to_01': False,
    'interpolation_method': 'images.make_piece_wise_constant_interpolation',
    'minimal_coordinate': -1.,
    'maximal_coordinate':1.,
}   
config.data_index = None
config.loss_function = 'losses.scaled_mse_loss'
config.take_grad_of_target_function = False
# Disabled alternatives for a state update function, kept for reference:
#config.state_update_function = ('auxiliary.ilm_updater', {num_steps = 10000})
# config.state_update_function = ('state_test_objects.py', 'counter_updater')
config.sampler = ('sampling.GridSubsetSampler',{  # samples coordinates in a fixed grid that should, in this case, coincide with the pixel locations in the image
    'size': variable([2040, 1356], [240, 320], group="datapoint"),  # previously the single value [2040, 1356]
    'batch_size': variable(27120, 32*240, group="datapoint"),  # previously the single value 2000
    'allow_duplicates': False,
    'min':-1.
})

# Optimizer settings.
config.optimizer = 'adam'  # we'll have to add optax to the additional default modules later
# config.optimizer = 'sgd'
config.optimizer_config = {
    # Two candidate learning rates; no group is given, so this variable is not tied to "method" or "datapoint".
    'learning_rate': variable(1.e-4, 1.1e-4)
}
config.steps = 160000  # changed from 40000
config.use_wandb = True

# Now we want some extra things, like logging, to happen during training.
# The inr_utils.training.train_inr function allows for this through callbacks.
# The callbacks we want to use can be found in inr_utils.callbacks.
config.after_step_callback = 'callbacks.ComposedCallback'
config.after_step_callback_config = {
    'callbacks':[
        ('callbacks.print_loss', {'after_every':400}),  # only print the loss every 400th step
        'callbacks.report_loss',  # but log the loss to wandb after every step
        ('callbacks.MetricCollectingCallback', # this will help us collect metrics and log images to wandb
             {'metric_collector':'metrics.MetricCollector'}
        ),
        'callbacks.raise_error_on_nan'  # stop training if the loss becomes NaN
    ],
    'show_logs': False
}

# Given as a (file, name) pair rather than a dotted path like the callbacks above.
# NOTE(review): presumably resolved relative to the directory the training script runs from - confirm.
config.after_training_callback = ('state_test_objects.py', 'after_training_callback')

config.metric_collector_config = {  # the metrics for MetricCollectingCallback / metrics.MetricCollector
    'metrics':[
        # Disabled metrics, kept for reference:
        # ('metrics.PlotOnGrid2D', {'grid': 256, 'batch_size':8*256, 'frequency':'every_n_batches'}),  
        # # ^ plots the image on this fixed grid so we can visually inspect the inr on wandb
        # ('metrics.MSEOnFixedGrid', {'grid': [2040, 1356], 'batch_size':2040, 'frequency': 'every_n_batches'})
        # ^ computes the MSE with the actual image pixels
        ('metrics.ImageGradMetrics', {
            # Same grid sizes and group as the sampler's 'size', so the grid follows the selected image.
            'grid':variable([2040, 1356],[240, 320], group="datapoint"), 
            'batch_size': variable(2040, 2400, group="datapoint"), 
            'frequency': 'every_n_batches'
            }),
    ],
    'batch_frequency': 400,  # compute all of these metrics every 400 batches
    'epoch_frequency': 1  # not actually used
}

# Disabled: an after_training_callback is already configured earlier in this cell. Setting it to None would
# switch it off; it could otherwise e.g. store some nice loss plots if you're not using wandb.
#config.after_training_callback = None
config.optimizer_state = None  # we're starting from scratch
# The wandb group shares the "method" group with the layer type, so each realization is labelled by its method.
# It is also used as the prefix of the generated config file names.
config.wandb_group = variable("Siren_example_grad", "Gaussian_example_grad", group="method")
config.wandb_entity = "nld"
config.wandb_project = "inr_edu_24"

In [3]:
# Directory that receives one config file per realization of `config`.
target_dir = "./factory_configs/test"
# exist_ok=True keeps the cell re-runnable and avoids the check-then-create race
# of testing os.path.exists() before calling os.makedirs().
os.makedirs(target_dir, exist_ok=True)

# Paths of the config files written by this cell, in the order they were written.
config_files = []

class MyEncoder(json.JSONEncoder):
    """JSON encoder that also knows how to serialize ``Config`` objects.

    A ``Config`` is replaced by the content of its ``data`` attribute. Every other
    object the stock encoder cannot handle is handed to the base class.
    """

    def default(self, o):
        """Return a JSON-serializable stand-in for ``o``.

        Only called by the encoder for objects it cannot serialize natively.
        """
        if not isinstance(o, Config):
            # Not one of ours: defer to the base implementation.
            return super().default(o)
        return o.data

# Write every realization of the swept config to its own file and remember the path.
for realization_idx, realization in enumerate(variable.realizations(config)):
    # File name is "<wandb_group>-<index>.yaml", so the method is visible in the name.
    realization_path = f"{target_dir}/{realization['wandb_group']}-{realization_idx}.yaml"
    # NB: the content is JSON (written with json.dump) even though the extension is .yaml.
    with open(realization_path, "w") as out_file:
        json.dump(realization, out_file, cls=MyEncoder)
    config_files.append(realization_path)
    

In [None]:
# Now create the slurm files that do what we want. NB: you'll probably need to modify the account
# and the time limit in the script template.
slurm_directory = "./factory_slurm/test"  # if you make this more or less nested, maybe change the "cd ../.."
# exist_ok=True keeps the cell re-runnable and avoids the check-then-create race
# of testing os.path.exists() before calling os.makedirs().
os.makedirs(slurm_directory, exist_ok=True)

# Shared head of every generated Slurm script; the per-config `python ...` line is appended to it.
# The backslash right after the opening quotes suppresses the leading newline, so that
# "#!/bin/bash" is the very first line of the script: sbatch rejects batch scripts whose
# first line does not start with "#!".
slurm_base = """\
#!/bin/bash
#SBATCH --account=tesr82932
#SBATCH --time=3:00:00
#SBATCH -p gpu
#SBATCH -N 1
#SBATCH --tasks-per-node 1
#SBATCH --gpus=1
#SBATCH --output=R-%x.%j.out
module load 2023
module load Miniconda3/23.5.2-0

# >>> conda initialize >>>
# !! Contents within this block are managed by 'conda init' !!
__conda_setup="$('/sw/arch/RHEL8/EB_production/2023/software/Miniconda3/23.5.2-0/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
if [ $? -eq 0 ]; then
    eval "$__conda_setup"
else
    if [ -f "/sw/arch/RHEL8/EB_production/2023/software/Miniconda3/23.5.2-0/etc/profile.d/conda.sh" ]; then
        . "/sw/arch/RHEL8/EB_production/2023/software/Miniconda3/23.5.2-0/etc/profile.d/conda.sh"
    else
        export PATH="/sw/arch/RHEL8/EB_production/2023/software/Miniconda3/23.5.2-0/bin:$PATH"
    fi
fi
unset __conda_setup
# <<< conda initialize <<<

conda init bash
conda activate snel_bep  # conda environment name

wandblogin="$(< ./wandb.login)"  # password stored in a file, don't add this file to your git repo!
wandb login "$wandblogin"

cd ../..
echo 'Starting new experiment!';
"""

# Emit one Slurm script per config file written earlier.
for config_path in config_files:
    # Script name: the config's base name up to its first '.', with a ".bash" extension.
    script_name = config_path.rpartition("/")[2].partition(".")[0] + ".bash"
    # Shared template plus the command that runs this particular config.
    script_text = slurm_base + f"\npython run_single.py --config={config_path}"
    with open(f"{slurm_directory}/{script_name}", "w") as script_file:
        script_file.write(script_text)
