diff --git a/varipeps/__init__.py b/varipeps/__init__.py index f43467b..3ec7ab2 100644 --- a/varipeps/__init__.py +++ b/varipeps/__init__.py @@ -20,11 +20,4 @@ jax_config.update("jax_enable_x64", True) -from tqdm_loggable.tqdm_logging import tqdm_logging -import datetime - -tqdm_logging.set_log_rate(datetime.timedelta(seconds=60)) - -del datetime -del tqdm_logging del jax_config diff --git a/varipeps/config.py b/varipeps/config.py index bf48686..6119382 100644 --- a/varipeps/config.py +++ b/varipeps/config.py @@ -1,11 +1,12 @@ from dataclasses import dataclass from enum import Enum, IntEnum, auto, unique +from typing import TypeVar, Tuple, Any, Type, NoReturn +import logging import numpy as np from jax.tree_util import register_pytree_node_class -from typing import TypeVar, Tuple, Any, Type, NoReturn T_VariPEPS_Config = TypeVar("T_VariPEPS_Config", bound="VariPEPS_Config") @@ -54,6 +55,15 @@ class Slurm_Restart_Mode(IntEnum): AUTOMATIC_RESTART = auto() #: Write restart script and start new slurm job with it +@unique +class LogLevel(IntEnum): + OFF = 0 + ERROR = logging.ERROR + WARNING = logging.WARNING + INFO = logging.INFO + DEBUG = logging.DEBUG + + @dataclass @register_pytree_node_class class VariPEPS_Config: @@ -225,6 +235,27 @@ class VariPEPS_Config: Type of wavevector to be used (only positive/symmetric interval/...). slurm_restart_mode (:obj:`Slurm_Restart_Mode`): Mode of operation to restart slurm job if maximal runtime is reached. + log_level_global (:obj:`LogLevel`): + Global logging level for the 'varipeps' package logger. + log_level_optimizer (:obj:`LogLevel`): + Logging level for 'varipeps.optimizer'. + log_level_ctmrg (:obj:`LogLevel`): + Logging level for 'varipeps.ctmrg'. + log_level_line_search (:obj:`LogLevel`): + Logging level for 'varipeps.line_search'. + log_level_expectation (:obj:`LogLevel`): + Logging level for 'varipeps.expectation'. + log_to_console (:obj:`bool`): + Enable standard console logging (StreamHandler). + Ignored when :obj:`VariPEPS_Config.log_tqdm` is True. + log_to_file (:obj:`bool`): + Enable logging to file. + log_file (:obj:`str`): + Filename for logging to file (used when :obj:`VariPEPS_Config.log_to_file` is True). + log_tqdm (:obj:`bool`): + Enable tqdm-based console logging. If True, messages from + 'varipeps.optimizer' update a tqdm progress bar, while other modules + log via tqdm.write. File logging settings still apply. 
""" # AD config @@ -310,6 +341,17 @@ class VariPEPS_Config: # Slurm slurm_restart_mode: Slurm_Restart_Mode = Slurm_Restart_Mode.WRITE_NEED_RESTART_FILE + # Logging configuration + log_level_global: LogLevel = LogLevel.INFO + log_level_optimizer: LogLevel = LogLevel.INFO + log_level_ctmrg: LogLevel = LogLevel.INFO + log_level_line_search: LogLevel = LogLevel.INFO + log_level_expectation: LogLevel = LogLevel.INFO + log_to_console: bool = True + log_to_file: bool = False + log_file: str = "varipeps.log" + log_tqdm: bool = False #: Enable tqdm-based console logging + def update(self, name: str, value: Any) -> NoReturn: self.__setattr__(name, value) @@ -346,12 +388,33 @@ def __setattr__(self, name: str, value: Any) -> NoReturn: elif ( field.type is bool and hasattr(value, "dtype") - and np.isdtype(value.dtype, np.bool) + and np.issubdtype(value.dtype, np.bool_) and value.size == 1 ): if value.ndim > 0: value = value.reshape(-1)[0] value = bool(value) + elif isinstance(field.type, type) and issubclass(field.type, Enum): + # Accept ints/np.int64 or enum names for Enum fields + if isinstance(value, field.type): + pass + elif isinstance(value, (int,)) or ( + hasattr(value, "dtype") + and np.issubdtype(value.dtype, np.integer) + and value.size == 1 + ): + if hasattr(value, "ndim") and value.ndim > 0: + value = value.reshape(-1)[0] + value = field.type(int(value)) + elif isinstance(value, str): + try: + value = field.type[value] + except KeyError: + value = field.type(int(value)) + else: + raise TypeError( + f"Type mismatch for option '{name}', got '{type(value)}', expected '{field.type}'." + ) else: raise TypeError( f"Type mismatch for option '{name}', got '{type(value)}', expected '{field.type}'." @@ -395,6 +458,7 @@ class ConfigModuleWrapper: "Projector_Method", "Wavevector_Type", "Slurm_Restart_Mode", + "LogLevel", "VariPEPS_Config", "config", } diff --git a/varipeps/ctmrg/routine.py b/varipeps/ctmrg/routine.py index 243a24e..f00e963 100644 --- a/varipeps/ctmrg/routine.py +++ b/varipeps/ctmrg/routine.py @@ -5,6 +5,11 @@ from jax import jit, custom_vjp, vjp, tree_util from jax.lax import cond, while_loop import jax.debug as jdebug +import logging +import time +import jax + +logger = logging.getLogger("varipeps.ctmrg") from varipeps import varipeps_config, varipeps_global_state from varipeps.peps import PEPS_Tensor, PEPS_Tensor_Split_Transfer, PEPS_Unit_Cell @@ -515,9 +520,8 @@ def corner_svd_func(old, new, old_corner, conv_eps, config): eps, config, ) - - if config.ctmrg_print_steps: - debug_print("CTMRG: {}: {}", count, measure) + if logger.isEnabledFor(logging.DEBUG): + jax.debug.callback(lambda cnt, msr: logger.debug(f"CTMRG: Step {cnt}: {msr}"), count, measure, ordered=True) if config.ctmrg_verbose_output: jax.debug.callback(print_verbose, verbose_data, ordered=True) @@ -620,9 +624,9 @@ def calc_ctmrg_env( best_norm_smallest_S = None best_truncation_eps = None have_been_increased = False - while True: tmp_count = 0 + t0 = time.perf_counter() corner_singular_vals = None while tmp_count < varipeps_config.ctmrg_max_steps and ( @@ -720,6 +724,17 @@ def calc_ctmrg_env( else: converged = False end_count = tmp_count + + if not converged and logger.isEnabledFor(logging.WARNING): + logger.warning( + "CTMRG: ❌ did not converge, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)", + time.perf_counter() - t0, end_count, norm_smallest_S + ) + elif logger.isEnabledFor(logging.INFO): + logger.info( + "CTMRG: ✅ converged, took %.2f seconds. 
(Steps: %d, Smallest SVD Norm: %.3e)", + time.perf_counter() - t0, end_count, norm_smallest_S + ) if converged and ( working_unitcell[0, 0][0][0].chi > best_chi or best_result is None @@ -751,9 +766,9 @@ def calc_ctmrg_env( working_unitcell = working_unitcell.change_chi(new_chi) initial_unitcell = initial_unitcell.change_chi(new_chi) - if varipeps_config.ctmrg_print_steps: - debug_print( - "CTMRG: Increasing chi to {} since smallest SVD Norm was {}.", + if logger.isEnabledFor(logging.INFO): + logger.info( + "Increasing chi to %d since smallest SVD Norm was %.3e.", new_chi, norm_smallest_S, ) @@ -785,9 +800,9 @@ def calc_ctmrg_env( if not new_chi in already_tried_chi: working_unitcell = working_unitcell.change_chi(new_chi) - if varipeps_config.ctmrg_print_steps: - debug_print( - "CTMRG: Decreasing chi to {} since smallest SVD Norm was {} or routine did not converge.", + if logger.isEnabledFor(logging.INFO): + logger.info( + "Decreasing chi to %d since smallest SVD Norm was %.3e or routine did not converge.", new_chi, norm_smallest_S, ) @@ -809,9 +824,9 @@ def calc_ctmrg_env( new_truncation_eps <= varipeps_config.ctmrg_increase_truncation_eps_max_value ): - if varipeps_config.ctmrg_print_steps: - debug_print( - "CTMRG: Increasing SVD truncation eps to {}.", + if logger.isEnabledFor(logging.INFO): + logger.info( + "Increasing SVD truncation eps to %.1e.", new_truncation_eps, ) varipeps_global_state.ctmrg_effective_truncation_eps = ( @@ -884,6 +899,8 @@ def calc_ctmrg_env_fwd( Internal helper function of custom VJP to calculate the values in the forward sweep. """ + if logger.isEnabledFor(logging.DEBUG): + logger.debug("Custom VJP: Starting forward CTMRG calculation.") new_unitcell, last_truncation_eps, norm_smallest_S = calc_ctmrg_env_custom_rule( peps_tensors, unitcell, _return_truncation_eps=True ) @@ -937,8 +954,8 @@ def _ctmrg_rev_while_body(carry): count += 1 - if config.ad_custom_print_steps: - debug_print("Custom VJP: {}: {}", count, measure) + if logger.isEnabledFor(logging.DEBUG): + jax.debug.callback(lambda cnt, msr: logger.debug(f"Custom VJP: Step {cnt}: {msr}"), count, measure, ordered=True) if config.ad_custom_verbose_output: jax.debug.callback(print_verbose, verbose_data, ordered=True, ad=True) @@ -1009,17 +1026,31 @@ def calc_ctmrg_env_rev( Internal helper function of custom VJP to calculate the gradient in the backward sweep. """ + if logger.isEnabledFor(logging.DEBUG): + logger.debug("Custom VJP: Starting reverse CTMRG calculation.") unitcell_bar, _ = input_bar peps_tensors, new_unitcell, input_unitcell, last_truncation_eps = res varipeps_global_state.ctmrg_effective_truncation_eps = last_truncation_eps + if logger.isEnabledFor(logging.WARNING): + t0 = time.perf_counter() t_bar, converged, end_count = _ctmrg_rev_workhorse( peps_tensors, new_unitcell, unitcell_bar, varipeps_config, varipeps_global_state ) varipeps_global_state.ctmrg_effective_truncation_eps = None + if not converged and logger.isEnabledFor(logging.WARNING): + logger.warning( + "Custom VJP: ❌ did not converge, took %.2f seconds. (Steps: %d)", + time.perf_counter() - t0, end_count + ) + elif logger.isEnabledFor(logging.INFO): + logger.info( + "Custom VJP: ✅ converged, took %.2f seconds. 
(Steps: %d)", + time.perf_counter() - t0, end_count + ) if end_count == varipeps_config.ad_custom_max_steps and not converged: raise CTMRGGradientNotConvergedError diff --git a/varipeps/ctmrg/structure_factor_routine.py b/varipeps/ctmrg/structure_factor_routine.py index 29ec384..585e576 100644 --- a/varipeps/ctmrg/structure_factor_routine.py +++ b/varipeps/ctmrg/structure_factor_routine.py @@ -5,6 +5,10 @@ from jax import jit, custom_vjp, vjp, tree_util from jax.lax import cond, while_loop import jax.debug as jdebug +import logging +import time + +logger = logging.getLogger("varipeps.ctmrg") from varipeps import varipeps_config, varipeps_global_state from varipeps.peps import PEPS_Tensor, PEPS_Unit_Cell @@ -125,8 +129,8 @@ def _ctmrg_body_func_structure_factor(carry): measure = jnp.linalg.norm(corner_svd - last_corner_svd) converged = measure < eps - if config.ctmrg_print_steps: - debug_print("CTMRG: {}: {}", count, measure) + if logger.isEnabledFor(logging.DEBUG): + jax.debug.callback(lambda cnt, msr: logger.debug(f"CTMRG: Step {cnt}: {msr}"), count, measure, ordered=True) if config.ctmrg_verbose_output: for ti, ctm_enum_i, diff in verbose_data: debug_print( @@ -244,6 +248,7 @@ def calc_ctmrg_env_structure_factor( norm_smallest_S = jnp.nan already_tried_chi = {working_unitcell[0, 0][0][0].chi} + t0 = time.perf_counter() while True: tmp_count = 0 corner_singular_vals = None @@ -304,6 +309,17 @@ def calc_ctmrg_env_structure_factor( ) ) + if not converged and logger.isEnabledFor(logging.WARNING): + logger.warning( + "CTMRG (SF): ❌ did not converge, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)", + time.perf_counter() - t0, end_count, norm_smallest_S + ) + elif logger.isEnabledFor(logging.INFO): + logger.info( + "CTMRG (SF): ✅ converged, took %.2f seconds. 
(Steps: %d, Smallest SVD Norm: %.3e)", + time.perf_counter() - t0, end_count, norm_smallest_S + ) + current_truncation_eps = ( varipeps_config.ctmrg_truncation_eps if varipeps_global_state.ctmrg_effective_truncation_eps is None @@ -326,15 +342,14 @@ def calc_ctmrg_env_structure_factor( working_unitcell = working_unitcell.change_chi(new_chi) initial_unitcell = initial_unitcell.change_chi(new_chi) - if varipeps_config.ctmrg_print_steps: - debug_print( - "CTMRG: Increasing chi to {} since smallest SVD Norm was {}.", + if logger.isEnabledFor(logging.INFO): + logger.info( + "CTMRG (SF): Increasing chi to %d since smallest SVD Norm was %.3e.", new_chi, norm_smallest_S, ) already_tried_chi.add(new_chi) - continue elif ( varipeps_config.ctmrg_heuristic_decrease_chi @@ -351,15 +366,14 @@ def calc_ctmrg_env_structure_factor( if not new_chi in already_tried_chi: working_unitcell = working_unitcell.change_chi(new_chi) - if varipeps_config.ctmrg_print_steps: - debug_print( - "CTMRG: Decreasing chi to {} since smallest SVD Norm was {}.", + if logger.isEnabledFor(logging.INFO): + logger.info( + "CTMRG (SF): Decreasing chi to %d since smallest SVD Norm was %.3e.", new_chi, norm_smallest_S, ) already_tried_chi.add(new_chi) - continue if ( @@ -375,9 +389,9 @@ def calc_ctmrg_env_structure_factor( new_truncation_eps <= varipeps_config.ctmrg_increase_truncation_eps_max_value ): - if varipeps_config.ctmrg_print_steps: - debug_print( - "CTMRG: Increasing SVD truncation eps to {}.", + if logger.isEnabledFor(logging.INFO): + logger.info( + "CTMRG (SF): Increasing SVD truncation eps to %g.", new_truncation_eps, ) varipeps_global_state.ctmrg_effective_truncation_eps = ( diff --git a/varipeps/optimization/line_search.py b/varipeps/optimization/line_search.py index 6a974b9..d70a5b7 100644 --- a/varipeps/optimization/line_search.py +++ b/varipeps/optimization/line_search.py @@ -1,6 +1,5 @@ import enum -from tqdm_loggable.auto import tqdm import jax import jax.numpy as jnp @@ -14,6 +13,9 @@ from varipeps.expectation import Expectation_Model from varipeps.mapping import Map_To_PEPS_Model from varipeps.utils.debug_print import debug_print +import logging + +logger = logging.getLogger("varipeps.line_search") from .inner_function import ( calc_ctmrg_expectation, @@ -443,6 +445,7 @@ def line_search( additional_input, enforce_elementwise_convergence=enforce_elementwise_convergence, ) + logger.info("🔎 Line search step %d, E=%.6f, alpha=%.3e", count + 1, new_value, alpha) if new_unitcell[0, 0][0][0].chi > unitcell[0, 0][0][0].chi: tmp_value = current_value @@ -463,10 +466,11 @@ def line_search( else: unitcell = unitcell.change_chi(new_unitcell[0, 0][0][0].chi) - debug_print( - "Line search: Recalculate original unitcell with higher chi {}.", - new_unitcell[0, 0][0][0].chi, - ) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Line search: Recalculate original unitcell with higher chi %s.", + new_unitcell[0, 0][0][0].chi, + ) if varipeps_config.ad_use_custom_vjp: ( @@ -534,6 +538,7 @@ def line_search( additional_input, calc_preconverged=True, ) + logger.info("🔎 Line search step %d, E=%.8f, alpha=%.8f", count + 1, new_value, alpha) new_gradient = [elem.conj() for elem in new_gradient_seq] if new_unitcell[0, 0][0][0].chi > unitcell[0, 0][0][0].chi: @@ -554,11 +559,11 @@ def line_search( ) = cache_original_unitcell[new_unitcell[0, 0][0][0].chi] else: unitcell = unitcell.change_chi(new_unitcell[0, 0][0][0].chi) - - debug_print( - "Line search: Recalculate original unitcell with higher chi {}.", - new_unitcell[0, 
0][0][0].chi, - ) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Line search: Recalculate original unitcell with higher chi %s.", + new_unitcell[0, 0][0][0].chi, + ) if varipeps_config.ad_use_custom_vjp: ( @@ -1002,7 +1007,7 @@ def line_search( ) if alpha <= 0: - tqdm.write("Found negative alpha in secant operation!") + logger.warning("Found negative alpha in secant operation!") hz_secant_alpha = alpha @@ -1120,6 +1125,7 @@ def line_search( jax.clear_caches() if count == varipeps_config.line_search_max_steps: + logger.warning("❗ No suitable step size found in line search!") raise NoSuitableStepSizeError(f"Count {count}, Last alpha {alpha}") return ( diff --git a/varipeps/optimization/optimizer.py b/varipeps/optimization/optimizer.py index 0038cbe..46c0ba7 100644 --- a/varipeps/optimization/optimizer.py +++ b/varipeps/optimization/optimizer.py @@ -11,8 +11,6 @@ from scipy.optimize import OptimizeResult -from tqdm_loggable.auto import tqdm - import h5py import numpy as np @@ -23,6 +21,10 @@ from jax.lax import scan from jax.flatten_util import ravel_pytree +import logging + +logger = logging.getLogger("varipeps.optimizer") + from varipeps import varipeps_config, varipeps_global_state from varipeps.config import Optimizing_Methods, Slurm_Restart_Mode from varipeps.peps import PEPS_Unit_Cell @@ -32,6 +34,7 @@ from varipeps.ctmrg import CTMRGNotConvergedError, CTMRGGradientNotConvergedError from varipeps.utils.random import PEPS_Random_Number_Generator from varipeps.utils.slurm import SlurmUtils +from varipeps.utils.logging_config import ensure_logging_configured from .inner_function import ( calc_ctmrg_expectation, @@ -202,11 +205,14 @@ def autosave_function( counter: Optional[Union[int, str]] = None, auxiliary_data: Optional[Dict[str, Any]] = None, ) -> None: + t0 = time.perf_counter() if counter is not None: unitcell.save_to_file( f"{str(filename)}.{counter}", auxiliary_data=auxiliary_data ) + logger.debug(f"💾 Autosaving to {str(filename)}.{counter}, took {time.perf_counter() - t0:.2f} sec") else: unitcell.save_to_file(filename, auxiliary_data=auxiliary_data) + logger.debug(f"💾 Autosaving to {str(filename)}, took {time.perf_counter() - t0:.2f} sec") @@ -231,6 +237,7 @@ def autosave_function_restartable( signal_reset_descent_dir, ) -> None: state_filename = os.environ.get("VARIPEPS_STATE_FILE") + t0 = time.perf_counter() if state_filename is None: state_filename = f"{str(filename)}.restartable" with h5py.File(state_filename, "w", libver=("earliest", "v110")) as f: @@ -361,6 +368,7 @@ def autosave_function_restartable( compression="gzip", compression_opts=6, ) + logger.debug(f"💾 Restartable autosaving to {str(state_filename)}, took {time.perf_counter() - t0:.2f} sec") def _autosave_wrapper( @@ -493,6 +501,8 @@ def optimize_peps_network( final expectation value. See the type definition for other possible fields. """ + ensure_logging_configured() + logger.info("🛠️ Starting optimization ... 
") rng = PEPS_Random_Number_Generator.get_generator(backend="jax") def random_noise(a): @@ -604,60 +614,237 @@ def random_noise(a): slurm_restart_written = False slurm_new_job_id = None - with tqdm(desc="Optimizing PEPS state", initial=count) as pbar: - while count < varipeps_config.optimizer_max_steps: - runtime_start = time.perf_counter() + while count < varipeps_config.optimizer_max_steps: + runtime_start = time.perf_counter() - chi_before_ctmrg = working_unitcell[0, 0][0][0].chi - try: - if varipeps_config.ad_use_custom_vjp: - ( - working_value, - (working_unitcell, _), - ), working_gradient_seq = calc_ctmrg_expectation_custom_value_and_grad( - working_tensors, - working_unitcell, - expectation_func, - convert_to_unitcell_func, - additional_input, + chi_before_ctmrg = working_unitcell[0, 0][0][0].chi + try: + if varipeps_config.ad_use_custom_vjp: + ( + working_value, + (working_unitcell, _), + ), working_gradient_seq = calc_ctmrg_expectation_custom_value_and_grad( + working_tensors, + working_unitcell, + expectation_func, + convert_to_unitcell_func, + additional_input, + ) + else: + ( + working_value, + (working_unitcell, _), + ), working_gradient_seq = calc_preconverged_ctmrg_value_and_grad( + working_tensors, + working_unitcell, + expectation_func, + convert_to_unitcell_func, + additional_input, + calc_preconverged=(count == 0), + ) + except (CTMRGNotConvergedError, CTMRGGradientNotConvergedError) as e: + varipeps_global_state.ctmrg_projector_method = None + + if random_noise_retries == 0: + return OptimizeResult( + success=False, + message=str(type(e)), + x=working_tensors, + fun=working_value, + unitcell=working_unitcell, + nit=count, + max_trunc_error_list=max_trunc_error_list, + step_energies=step_energies, + step_chi=step_chi, + step_conv=step_conv, + step_runtime=step_runtime, + best_run=0, + ) + elif ( + random_noise_retries + >= varipeps_config.optimizer_random_noise_max_retries + ): + working_value = jnp.inf + break + else: + if isinstance(input_tensors, PEPS_Unit_Cell) or ( + isinstance(input_tensors, collections.abc.Sequence) + and isinstance(input_tensors[0], PEPS_Unit_Cell) + ): + working_tensors = ( + cast( + List[jnp.ndarray], + [i.tensor for i in best_unitcell.get_unique_tensors()], + ) + + best_tensors[best_unitcell.get_len_unique_tensors() :] + ) + + working_tensors = [random_noise(i) for i in working_tensors] + + working_tensors_obj = [ + e.replace_tensor(working_tensors[i]) + for i, e in enumerate(best_unitcell.get_unique_tensors()) + ] + + working_unitcell = best_unitcell.replace_unique_tensors( + working_tensors_obj ) else: + working_tensors = [random_noise(i) for i in best_tensors] + working_unitcell = None + + descent_dir = None + working_gradient = None + signal_reset_descent_dir = True + count = 0 + random_noise_retries += 1 + old_descent_dir = descent_dir + old_gradient = working_gradient + + step_energies[random_noise_retries] = [] + step_chi[random_noise_retries] = [] + step_conv[random_noise_retries] = [] + max_trunc_error_list[random_noise_retries] = [] + step_runtime[random_noise_retries] = [] + + continue + + if working_unitcell[0, 0][0][0].chi != chi_before_ctmrg: + jax.clear_caches() + + working_gradient = [elem.conj() for elem in working_gradient_seq] + + if signal_reset_descent_dir: + if varipeps_config.optimizer_method is Optimizing_Methods.BFGS: + bfgs_prefactor = ( + 2 if any(jnp.iscomplexobj(t) for t in working_tensors) else 1 + ) + bfgs_B_inv = jnp.eye( + bfgs_prefactor * sum([t.size for t in working_tensors]) + ) + elif 
varipeps_config.optimizer_method is Optimizing_Methods.L_BFGS: + l_bfgs_x_cache = deque( + maxlen=varipeps_config.optimizer_l_bfgs_maxlen + 1 + ) + l_bfgs_grad_cache = deque( + maxlen=varipeps_config.optimizer_l_bfgs_maxlen + 1 + ) + + if varipeps_config.optimizer_method is Optimizing_Methods.STEEPEST: + descent_dir = [-elem for elem in working_gradient] + elif varipeps_config.optimizer_method is Optimizing_Methods.CG: + if count == 0 or signal_reset_descent_dir: + descent_dir = [-elem for elem in working_gradient] + else: + descent_dir, beta = _cg_workhorse( + working_gradient, old_gradient, old_descent_dir + ) + elif varipeps_config.optimizer_method is Optimizing_Methods.BFGS: + if count == 0 or signal_reset_descent_dir: + descent_dir, _ = _bfgs_workhorse( + working_gradient, None, None, None, bfgs_B_inv, False + ) + else: + descent_dir, bfgs_B_inv = _bfgs_workhorse( + working_gradient, + old_gradient, + old_descent_dir, + linesearch_step, + bfgs_B_inv, + True, + ) + elif varipeps_config.optimizer_method is Optimizing_Methods.L_BFGS: + l_bfgs_x_cache.appendleft(tuple(working_tensors)) + l_bfgs_grad_cache.appendleft(tuple(working_gradient)) + + if count == 0 or signal_reset_descent_dir: + descent_dir = [-elem for elem in working_gradient] + else: + descent_dir = _l_bfgs_workhorse( + tuple(l_bfgs_x_cache), tuple(l_bfgs_grad_cache) + ) + else: + raise ValueError("Unknown optimization method.") + + signal_reset_descent_dir = False + + if _scalar_descent_grad(descent_dir, working_gradient) > 0: + logger.warning("Found bad descent dir. Reset to negative gradient!") + descent_dir = [-elem for elem in working_gradient] + + conv = jnp.linalg.norm(ravel_pytree(working_gradient)[0]) + step_conv[random_noise_retries].append(conv) + + try: + ( + working_tensors, + working_unitcell, + working_value, + linesearch_step, + signal_reset_descent_dir, + max_trunc_error, + ) = line_search( + working_tensors, + working_unitcell, + expectation_func, + working_gradient, + descent_dir, + working_value, + linesearch_step, + convert_to_unitcell_func, + generate_unitcell, + spiral_indices, + additional_input, + conv > varipeps_config.optimizer_reuse_env_eps, + ) + except NoSuitableStepSizeError: + runtime = time.perf_counter() - runtime_start + step_runtime[random_noise_retries].append(runtime) + + if varipeps_config.optimizer_fail_if_no_step_size_found: + raise + else: + if ( ( - working_value, - (working_unitcell, _), - ), working_gradient_seq = calc_preconverged_ctmrg_value_and_grad( - working_tensors, - working_unitcell, - expectation_func, - convert_to_unitcell_func, - additional_input, - calc_preconverged=(count == 0), + conv > varipeps_config.optimizer_random_noise_eps + or working_value > best_value ) - except (CTMRGNotConvergedError, CTMRGGradientNotConvergedError) as e: - varipeps_global_state.ctmrg_projector_method = None - - if random_noise_retries == 0: - return OptimizeResult( - success=False, - message=str(type(e)), - x=working_tensors, - fun=working_value, - unitcell=working_unitcell, - nit=count, - max_trunc_error_list=max_trunc_error_list, - step_energies=step_energies, - step_chi=step_chi, - step_conv=step_conv, - step_runtime=step_runtime, - best_run=0, + and random_noise_retries + < varipeps_config.optimizer_random_noise_max_retries + and not ( + varipeps_config.optimizer_preconverge_with_half_projectors + and not varipeps_global_state.basinhopping_disable_half_projector + and varipeps_global_state.ctmrg_projector_method + is Projector_Method.HALF ) - elif ( - random_noise_retries - >= 
varipeps_config.optimizer_random_noise_max_retries ): - working_value = jnp.inf - break - else: + logger.warning( + "⚠️ Convergence is not sufficient. Retry with some random noise on best result. 🔀" + ) + + if working_value < best_value: + best_value = working_value + best_tensors = working_tensors + best_unitcell = working_unitcell + best_run = random_noise_retries + + _autosave_wrapper( + autosave_func, + autosave_filename, + working_tensors, + working_unitcell, + working_value, + "best", + best_run, + max_trunc_error_list, + step_energies, + step_chi, + step_conv, + step_runtime, + spiral_indices, + additional_input, + ) + if isinstance(input_tensors, PEPS_Unit_Cell) or ( isinstance(input_tensors, collections.abc.Sequence) and isinstance(input_tensors[0], PEPS_Unit_Cell) @@ -665,7 +852,10 @@ def random_noise(a): working_tensors = ( cast( List[jnp.ndarray], - [i.tensor for i in best_unitcell.get_unique_tensors()], + [ + i.tensor + for i in best_unitcell.get_unique_tensors() + ], ) + best_tensors[best_unitcell.get_len_unique_tensors() :] ) @@ -674,7 +864,9 @@ def random_noise(a): working_tensors_obj = [ e.replace_tensor(working_tensors[i]) - for i, e in enumerate(best_unitcell.get_unique_tensors()) + for i, e in enumerate( + best_unitcell.get_unique_tensors() + ) ] working_unitcell = best_unitcell.replace_unique_tensors( @@ -698,350 +890,224 @@ def random_noise(a): max_trunc_error_list[random_noise_retries] = [] step_runtime[random_noise_retries] = [] - pbar.reset() - pbar.refresh() - - continue - - if working_unitcell[0, 0][0][0].chi != chi_before_ctmrg: - jax.clear_caches() - - working_gradient = [elem.conj() for elem in working_gradient_seq] - - if signal_reset_descent_dir: - if varipeps_config.optimizer_method is Optimizing_Methods.BFGS: - bfgs_prefactor = ( - 2 if any(jnp.iscomplexobj(t) for t in working_tensors) else 1 - ) - bfgs_B_inv = jnp.eye( - bfgs_prefactor * sum([t.size for t in working_tensors]) - ) - elif varipeps_config.optimizer_method is Optimizing_Methods.L_BFGS: - l_bfgs_x_cache = deque( - maxlen=varipeps_config.optimizer_l_bfgs_maxlen + 1 - ) - l_bfgs_grad_cache = deque( - maxlen=varipeps_config.optimizer_l_bfgs_maxlen + 1 - ) + if autosave_func is autosave_function: + descent_method_tuple = None + if ( + varipeps_config.optimizer_method + is Optimizing_Methods.BFGS + ): + descent_method_tuple = (bfgs_prefactor, bfgs_B_inv) + elif ( + varipeps_config.optimizer_method + is Optimizing_Methods.L_BFGS + ): + descent_method_tuple = ( + l_bfgs_x_cache, + l_bfgs_grad_cache, + ) + _autosave_wrapper( + partial( + autosave_function_restartable, + expectation_func=expectation_func, + convert_to_unitcell_func=convert_to_unitcell_func, + old_gradient=old_gradient, + old_descent_dir=old_descent_dir, + best_value=best_value, + best_tensors=best_tensors, + best_unitcell=best_unitcell, + random_noise_retries=random_noise_retries, + descent_method_tuple=descent_method_tuple, + count=count, + linesearch_step=linesearch_step, + projector_method=( + "HALF" + if varipeps_global_state.ctmrg_projector_method + is Projector_Method.HALF + else "FULL" + ), + signal_reset_descent_dir=signal_reset_descent_dir, + ), + autosave_filename, + working_tensors, + working_unitcell, + working_value, + None, + best_run, + max_trunc_error_list, + step_energies, + step_chi, + step_conv, + step_runtime, + spiral_indices, + additional_input, + ) - if varipeps_config.optimizer_method is Optimizing_Methods.STEEPEST: - descent_dir = [-elem for elem in working_gradient] - elif varipeps_config.optimizer_method is 
Optimizing_Methods.CG: - if count == 0 or signal_reset_descent_dir: - descent_dir = [-elem for elem in working_gradient] - else: - descent_dir, beta = _cg_workhorse( - working_gradient, old_gradient, old_descent_dir - ) - elif varipeps_config.optimizer_method is Optimizing_Methods.BFGS: - if count == 0 or signal_reset_descent_dir: - descent_dir, _ = _bfgs_workhorse( - working_gradient, None, None, None, bfgs_B_inv, False - ) - else: - descent_dir, bfgs_B_inv = _bfgs_workhorse( - working_gradient, - old_gradient, - old_descent_dir, - linesearch_step, - bfgs_B_inv, - True, - ) - elif varipeps_config.optimizer_method is Optimizing_Methods.L_BFGS: - l_bfgs_x_cache.appendleft(tuple(working_tensors)) - l_bfgs_grad_cache.appendleft(tuple(working_gradient)) - if count == 0 or signal_reset_descent_dir: - descent_dir = [-elem for elem in working_gradient] + continue else: - descent_dir = _l_bfgs_workhorse( - tuple(l_bfgs_x_cache), tuple(l_bfgs_grad_cache) - ) - else: - raise ValueError("Unknown optimization method.") - - signal_reset_descent_dir = False - - if _scalar_descent_grad(descent_dir, working_gradient) > 0: - tqdm.write("Found bad descent dir. Reset to negative gradient!") - descent_dir = [-elem for elem in working_gradient] + conv = 0 + else: + runtime = time.perf_counter() - runtime_start + step_runtime[random_noise_retries].append(runtime) + max_trunc_error_list[random_noise_retries].append(max_trunc_error) + step_energies[random_noise_retries].append(working_value) + step_chi[random_noise_retries].append( + working_unitcell.get_unique_tensors()[0].chi + ) - conv = jnp.linalg.norm(ravel_pytree(working_gradient)[0]) - step_conv[random_noise_retries].append(conv) + if ( + varipeps_config.optimizer_preconverge_with_half_projectors + and not varipeps_global_state.basinhopping_disable_half_projector + and varipeps_global_state.ctmrg_projector_method + is Projector_Method.HALF + and conv + < varipeps_config.optimizer_preconverge_with_half_projectors_eps + ): + varipeps_global_state.ctmrg_projector_method = ( + varipeps_config.ctmrg_full_projector_method + ) - try: - ( - working_tensors, - working_unitcell, - working_value, - linesearch_step, - signal_reset_descent_dir, - max_trunc_error, - ) = line_search( + working_value, (working_unitcell, max_trunc_error) = ( + calc_ctmrg_expectation( working_tensors, working_unitcell, expectation_func, - working_gradient, - descent_dir, - working_value, - linesearch_step, convert_to_unitcell_func, - generate_unitcell, - spiral_indices, additional_input, - conv > varipeps_config.optimizer_reuse_env_eps, + enforce_elementwise_convergence=varipeps_config.ad_use_custom_vjp, ) - except NoSuitableStepSizeError: - runtime = time.perf_counter() - runtime_start - step_runtime[random_noise_retries].append(runtime) + ) + descent_dir = None + working_gradient = None + signal_reset_descent_dir = True + conv = jnp.inf + linesearch_step = None + + if conv < varipeps_config.optimizer_convergence_eps: + working_value, ( + working_unitcell, + max_trunc_error, + ) = calc_ctmrg_expectation( + working_tensors, + working_unitcell, + expectation_func, + convert_to_unitcell_func, + additional_input, + enforce_elementwise_convergence=varipeps_config.ad_use_custom_vjp, + ) - if varipeps_config.optimizer_fail_if_no_step_size_found: - raise - else: - if ( - ( - conv > varipeps_config.optimizer_random_noise_eps - or working_value > best_value - ) - and random_noise_retries - < varipeps_config.optimizer_random_noise_max_retries - and not ( - 
varipeps_config.optimizer_preconverge_with_half_projectors - and not varipeps_global_state.basinhopping_disable_half_projector - and varipeps_global_state.ctmrg_projector_method - is Projector_Method.HALF - ) - ): - tqdm.write( - "Convergence is not sufficient. Retry with some random noise on best result." - ) + try: + max_trunc_error_list[random_noise_retries][-1] = max_trunc_error + except IndexError: + max_trunc_error_list[random_noise_retries].append(max_trunc_error) - if working_value < best_value: - best_value = working_value - best_tensors = working_tensors - best_unitcell = working_unitcell - best_run = random_noise_retries - - _autosave_wrapper( - autosave_func, - autosave_filename, - working_tensors, - working_unitcell, - working_value, - "best", - best_run, - max_trunc_error_list, - step_energies, - step_chi, - step_conv, - step_runtime, - spiral_indices, - additional_input, - ) + try: + step_energies[random_noise_retries][-1] = working_value + except IndexError: + step_energies[random_noise_retries].append(working_value) - if isinstance(input_tensors, PEPS_Unit_Cell) or ( - isinstance(input_tensors, collections.abc.Sequence) - and isinstance(input_tensors[0], PEPS_Unit_Cell) - ): - working_tensors = ( - cast( - List[jnp.ndarray], - [ - i.tensor - for i in best_unitcell.get_unique_tensors() - ], - ) - + best_tensors[best_unitcell.get_len_unique_tensors() :] - ) + try: + step_chi[random_noise_retries][ + -1 + ] = working_unitcell.get_unique_tensors()[0].chi + except IndexError: + step_chi[random_noise_retries].append( + working_unitcell.get_unique_tensors()[0].chi + ) - working_tensors = [random_noise(i) for i in working_tensors] + break - working_tensors_obj = [ - e.replace_tensor(working_tensors[i]) - for i, e in enumerate( - best_unitcell.get_unique_tensors() - ) - ] + old_descent_dir = descent_dir + old_gradient = working_gradient - working_unitcell = best_unitcell.replace_unique_tensors( - working_tensors_obj - ) - else: - working_tensors = [random_noise(i) for i in best_tensors] - working_unitcell = None - - descent_dir = None - working_gradient = None - signal_reset_descent_dir = True - count = 0 - random_noise_retries += 1 - old_descent_dir = descent_dir - old_gradient = working_gradient - - step_energies[random_noise_retries] = [] - step_chi[random_noise_retries] = [] - step_conv[random_noise_retries] = [] - max_trunc_error_list[random_noise_retries] = [] - step_runtime[random_noise_retries] = [] - - if autosave_func is autosave_function: - descent_method_tuple = None - if ( - varipeps_config.optimizer_method - is Optimizing_Methods.BFGS - ): - descent_method_tuple = (bfgs_prefactor, bfgs_B_inv) - elif ( - varipeps_config.optimizer_method - is Optimizing_Methods.L_BFGS - ): - descent_method_tuple = ( - l_bfgs_x_cache, - l_bfgs_grad_cache, - ) - _autosave_wrapper( - partial( - autosave_function_restartable, - expectation_func=expectation_func, - convert_to_unitcell_func=convert_to_unitcell_func, - old_gradient=old_gradient, - old_descent_dir=old_descent_dir, - best_value=best_value, - best_tensors=best_tensors, - best_unitcell=best_unitcell, - random_noise_retries=random_noise_retries, - descent_method_tuple=descent_method_tuple, - count=count, - linesearch_step=linesearch_step, - projector_method=( - "HALF" - if varipeps_global_state.ctmrg_projector_method - is Projector_Method.HALF - else "FULL" - ), - signal_reset_descent_dir=signal_reset_descent_dir, - ), - autosave_filename, - working_tensors, - working_unitcell, - working_value, - None, - best_run, - 
max_trunc_error_list, - step_energies, - step_chi, - step_conv, - step_runtime, - spiral_indices, - additional_input, - ) + count += 1 - pbar.reset() - pbar.refresh() + logger.info( + "📉 %4d | E=%.8f ΔE=%+.2e | r=%d | ‖∇ψ‖=%.2e | ε_tr=%.1e | χ=%d | t=%.0fs", + int(count), + float(working_value), + -float(working_value - step_energies[random_noise_retries][-2]) if len(step_energies[random_noise_retries]) > 1 else 0.0, + int(random_noise_retries), + float(conv), + float(max_trunc_error), + int(step_chi[random_noise_retries][-1]), + float(runtime), + ) - continue - else: - conv = 0 - else: - runtime = time.perf_counter() - runtime_start - step_runtime[random_noise_retries].append(runtime) - max_trunc_error_list[random_noise_retries].append(max_trunc_error) - step_energies[random_noise_retries].append(working_value) - step_chi[random_noise_retries].append( - working_unitcell.get_unique_tensors()[0].chi - ) + if count % varipeps_config.optimizer_autosave_step_count == 0: + _autosave_wrapper( + autosave_func, + autosave_filename, + working_tensors, + working_unitcell, + working_value, + random_noise_retries, + best_run, + max_trunc_error_list, + step_energies, + step_chi, + step_conv, + step_runtime, + spiral_indices, + additional_input, + ) - if ( + if working_value < best_value and not ( varipeps_config.optimizer_preconverge_with_half_projectors and not varipeps_global_state.basinhopping_disable_half_projector and varipeps_global_state.ctmrg_projector_method is Projector_Method.HALF - and conv - < varipeps_config.optimizer_preconverge_with_half_projectors_eps ): - varipeps_global_state.ctmrg_projector_method = ( - varipeps_config.ctmrg_full_projector_method - ) - - working_value, (working_unitcell, max_trunc_error) = ( - calc_ctmrg_expectation( - working_tensors, - working_unitcell, - expectation_func, - convert_to_unitcell_func, - additional_input, - enforce_elementwise_convergence=varipeps_config.ad_use_custom_vjp, - ) - ) - descent_dir = None - working_gradient = None - signal_reset_descent_dir = True - conv = jnp.inf - linesearch_step = None - - if conv < varipeps_config.optimizer_convergence_eps: - working_value, ( - working_unitcell, - max_trunc_error, - ) = calc_ctmrg_expectation( + _autosave_wrapper( + autosave_func, + autosave_filename, working_tensors, working_unitcell, - expectation_func, - convert_to_unitcell_func, + working_value, + "best", + random_noise_retries, + max_trunc_error_list, + step_energies, + step_chi, + step_conv, + step_runtime, + spiral_indices, additional_input, - enforce_elementwise_convergence=varipeps_config.ad_use_custom_vjp, ) - try: - max_trunc_error_list[random_noise_retries][-1] = max_trunc_error - except IndexError: - max_trunc_error_list[random_noise_retries].append(max_trunc_error) - - try: - step_energies[random_noise_retries][-1] = working_value - except IndexError: - step_energies[random_noise_retries].append(working_value) - - try: - step_chi[random_noise_retries][ - -1 - ] = working_unitcell.get_unique_tensors()[0].chi - except IndexError: - step_chi[random_noise_retries].append( - working_unitcell.get_unique_tensors()[0].chi - ) - - break - - old_descent_dir = descent_dir - old_gradient = working_gradient - - count += 1 - - pbar.update() - pbar.set_postfix( - { - "Energy": f"{working_value:0.10f}", - "Retries": random_noise_retries, - "Convergence": f"{conv:0.8f}", - "Line search step": ( - f"{linesearch_step:0.8f}" - if linesearch_step is not None - else "0" - ), - "Max. trunc. 
err.": f"{max_trunc_error:0.8g}", - } - ) - pbar.refresh() - - if count % varipeps_config.optimizer_autosave_step_count == 0: + if autosave_func is autosave_function: + descent_method_tuple = None + if varipeps_config.optimizer_method is Optimizing_Methods.BFGS: + descent_method_tuple = (bfgs_prefactor, bfgs_B_inv) + elif varipeps_config.optimizer_method is Optimizing_Methods.L_BFGS: + descent_method_tuple = (l_bfgs_x_cache, l_bfgs_grad_cache) _autosave_wrapper( - autosave_func, + partial( + autosave_function_restartable, + expectation_func=expectation_func, + convert_to_unitcell_func=convert_to_unitcell_func, + old_gradient=old_gradient, + old_descent_dir=old_descent_dir, + best_value=best_value, + best_tensors=best_tensors, + best_unitcell=best_unitcell, + random_noise_retries=random_noise_retries, + descent_method_tuple=descent_method_tuple, + count=count, + linesearch_step=linesearch_step, + projector_method=( + "HALF" + if varipeps_global_state.ctmrg_projector_method + is Projector_Method.HALF + else "FULL" + ), + signal_reset_descent_dir=signal_reset_descent_dir, + ), autosave_filename, working_tensors, working_unitcell, working_value, - random_noise_retries, + None, best_run, max_trunc_error_list, step_energies, @@ -1052,151 +1118,81 @@ def random_noise(a): additional_input, ) - if working_value < best_value and not ( - varipeps_config.optimizer_preconverge_with_half_projectors - and not varipeps_global_state.basinhopping_disable_half_projector - and varipeps_global_state.ctmrg_projector_method - is Projector_Method.HALF - ): - _autosave_wrapper( - autosave_func, - autosave_filename, - working_tensors, - working_unitcell, - working_value, - "best", - random_noise_retries, - max_trunc_error_list, - step_energies, - step_chi, - step_conv, - step_runtime, - spiral_indices, - additional_input, - ) - - if autosave_func is autosave_function: - descent_method_tuple = None - if varipeps_config.optimizer_method is Optimizing_Methods.BFGS: - descent_method_tuple = (bfgs_prefactor, bfgs_B_inv) - elif varipeps_config.optimizer_method is Optimizing_Methods.L_BFGS: - descent_method_tuple = (l_bfgs_x_cache, l_bfgs_grad_cache) - _autosave_wrapper( - partial( - autosave_function_restartable, - expectation_func=expectation_func, - convert_to_unitcell_func=convert_to_unitcell_func, - old_gradient=old_gradient, - old_descent_dir=old_descent_dir, - best_value=best_value, - best_tensors=best_tensors, - best_unitcell=best_unitcell, - random_noise_retries=random_noise_retries, - descent_method_tuple=descent_method_tuple, - count=count, - linesearch_step=linesearch_step, - projector_method=( - "HALF" - if varipeps_global_state.ctmrg_projector_method - is Projector_Method.HALF - else "FULL" - ), - signal_reset_descent_dir=signal_reset_descent_dir, - ), - autosave_filename, - working_tensors, - working_unitcell, - working_value, - None, - best_run, - max_trunc_error_list, - step_energies, - step_chi, - step_conv, - step_runtime, - spiral_indices, - additional_input, - ) + if working_value < best_value and not ( + varipeps_config.optimizer_preconverge_with_half_projectors + and not varipeps_global_state.basinhopping_disable_half_projector + and varipeps_global_state.ctmrg_projector_method + is Projector_Method.HALF + ): + best_value = working_value + best_tensors = working_tensors + best_unitcell = working_unitcell + best_run = random_noise_retries + + if ( + varipeps_config.slurm_restart_mode is not Slurm_Restart_Mode.DISABLED + and (slurm_data := SlurmUtils.get_own_job_data()) is not None + ): + 
flatten_runtime = [j for i in step_runtime for j in step_runtime[i]] + runtime_mean = np.mean(flatten_runtime) + runtime_std = np.std(flatten_runtime) - if working_value < best_value and not ( - varipeps_config.optimizer_preconverge_with_half_projectors - and not varipeps_global_state.basinhopping_disable_half_projector - and varipeps_global_state.ctmrg_projector_method - is Projector_Method.HALF - ): - best_value = working_value - best_tensors = working_tensors - best_unitcell = working_unitcell - best_run = random_noise_retries + remaining_slurm_time = slurm_data["TimeLimit"] - slurm_data["RunTime"] if ( - varipeps_config.slurm_restart_mode is not Slurm_Restart_Mode.DISABLED - and (slurm_data := SlurmUtils.get_own_job_data()) is not None - ): - flatten_runtime = [j for i in step_runtime for j in step_runtime[i]] - runtime_mean = np.mean(flatten_runtime) - runtime_std = np.std(flatten_runtime) + remaining_time_correction := os.environ.get( + "VARIPEPS_REMAINING_TIME_CORRECTION" + ) + ) is not None: + try: + remaining_time_correction = int(remaining_time_correction) + remaining_slurm_time -= datetime.timedelta( + seconds=remaining_time_correction + ) + except (TypeError, ValueError): + pass - remaining_slurm_time = slurm_data["TimeLimit"] - slurm_data["RunTime"] + time_of_one_step = datetime.timedelta( + seconds=runtime_mean + 3 * runtime_std + ) + if remaining_slurm_time < time_of_one_step: + logger.info("⏳ Average time of optimizer step below remaining Slurm runtime") if ( - remaining_time_correction := os.environ.get( - "VARIPEPS_REMAINING_TIME_CORRECTION" + restart_needed_filename := os.environ.get( + "VARIPEPS_NEED_RESTART_FILE" ) ) is not None: - try: - remaining_time_correction = int(remaining_time_correction) - remaining_slurm_time -= datetime.timedelta( - seconds=remaining_time_correction - ) - except (TypeError, ValueError): - pass - - time_of_one_step = datetime.timedelta( - seconds=runtime_mean + 3 * runtime_std - ) + pathlib.Path(restart_needed_filename).touch() - if remaining_slurm_time < time_of_one_step: - print( - "Average time of optimizer step below remaining Slurm runtime", - file=sys.stderr, + if ( + varipeps_config.slurm_restart_mode + is Slurm_Restart_Mode.WRITE_RESTART_SCRIPT + or varipeps_config.slurm_restart_mode + is Slurm_Restart_Mode.AUTOMATIC_RESTART + ): + SlurmUtils.generate_restart_scripts( + f"{str(autosave_filename)}.restart.slurm", + f"{str(autosave_filename)}.restart.py", + f"{str(autosave_filename)}.restartable", + slurm_data, ) - if ( - restart_needed_filename := os.environ.get( - "VARIPEPS_NEED_RESTART_FILE" - ) - ) is not None: - pathlib.Path(restart_needed_filename).touch() - - if ( - varipeps_config.slurm_restart_mode - is Slurm_Restart_Mode.WRITE_RESTART_SCRIPT - or varipeps_config.slurm_restart_mode - is Slurm_Restart_Mode.AUTOMATIC_RESTART - ): - SlurmUtils.generate_restart_scripts( - f"{str(autosave_filename)}.restart.slurm", - f"{str(autosave_filename)}.restart.py", - f"{str(autosave_filename)}.restartable", - slurm_data, - ) - - slurm_restart_written = True + slurm_restart_written = True - if ( - varipeps_config.slurm_restart_mode - is Slurm_Restart_Mode.AUTOMATIC_RESTART - ): - slurm_new_job_id = SlurmUtils.run_slurm_script( - f"{str(autosave_filename)}.restart.slurm", - slurm_data["WorkDir"], + if ( + varipeps_config.slurm_restart_mode + is Slurm_Restart_Mode.AUTOMATIC_RESTART + ): + slurm_new_job_id = SlurmUtils.run_slurm_script( + f"{str(autosave_filename)}.restart.slurm", + slurm_data["WorkDir"], + ) + if slurm_new_job_id is None: + 
logger.error( + "Failed to start new Slurm job or parse its job id." ) - if slurm_new_job_id is None: - tqdm.write( - "Failed to start new Slurm job or parse its job id." - ) - break + break if working_value < best_value: best_value = working_value diff --git a/varipeps/utils/__init__.py b/varipeps/utils/__init__.py index 5609f81..049cf27 100644 --- a/varipeps/utils/__init__.py +++ b/varipeps/utils/__init__.py @@ -4,3 +4,4 @@ from . import projector_dict from . import slurm from . import svd +from . import logging_config diff --git a/varipeps/utils/debug_print.py b/varipeps/utils/debug_print.py index f7f02d4..759fb9a 100644 --- a/varipeps/utils/debug_print.py +++ b/varipeps/utils/debug_print.py @@ -1,22 +1,23 @@ import functools +import logging -from tqdm_loggable.auto import tqdm import jax.debug as jdebug from jax._src.debugging import formatter -# Adapting function from jax.debug to work with tqdm +logger = logging.getLogger("varipeps.ctmrg") def _format_print_callback(fmt: str, *args, **kwargs): - tqdm.write(fmt.format(*args, **kwargs)) + # Send to logger (respects per-module levels/handlers) + logger.debug(fmt.format(*args, **kwargs)) def debug_print(fmt: str, *args, ordered: bool = True, **kwargs) -> None: """ Prints values and works in staged out JAX functions. - Function adapted from :obj:`jax.debug.print` to work with tqdm. See there + Function adapted from :obj:`jax.debug.print` to work with logger. See there for original authors and function. Args: diff --git a/varipeps/utils/logging_config.py b/varipeps/utils/logging_config.py new file mode 100644 index 0000000..32310f4 --- /dev/null +++ b/varipeps/utils/logging_config.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +import logging +from typing import Any + +from varipeps import config as _cfg_mod # uses the global config instance + +# --- Custom tqdm-based handlers --- + +class TqdmUpdateHandler(logging.Handler): + """Updates a tqdm progress bar's postfix string instead of printing.""" + def __init__(self, pbar: Any): + super().__init__() + self.pbar = pbar + + def emit(self, record: logging.LogRecord) -> None: + try: + msg = self.format(record) + # Truncate to keep the bar compact + self.pbar.set_postfix_str(str(msg), refresh=True) + except Exception: # nosec - logging must never raise + self.handleError(record) + + +class TqdmWriteHandler(logging.Handler): + """Writes messages via tqdm.write (thread-safe).""" + def emit(self, record: logging.LogRecord) -> None: + try: + msg = self.format(record) + from tqdm import tqdm + tqdm.write(str(msg)) + except Exception: + self.handleError(record) + + +class ExcludeLoggerFilter(logging.Filter): + """Exclude records whose logger name starts with a given prefix.""" + def __init__(self, prefix: str): + super().__init__() + self.prefix = prefix + + def filter(self, record: logging.LogRecord) -> bool: + return not record.name.startswith(self.prefix) + +_LOGGING_INITIALIZED = False + +def _to_py_log_level(level: Any) -> int: + # Accept both enum values and raw ints; OFF disables effectively + try: + val = int(level) + except Exception: + val = logging.INFO + if val == 0: # OFF + return logging.CRITICAL + 10 + return val + +def init_logging(cfg: Any | None = None) -> None: + """ + Initialize logging based on the provided config (or global config). + Safe to call multiple times; replaces handlers to avoid duplicates. 
+ """ + global _LOGGING_INITIALIZED + if cfg is None: + cfg = _cfg_mod.config + + root = logging.getLogger("varipeps") + # Remove old handlers to prevent duplicate logs + for h in list(root.handlers): + root.removeHandler(h) + + root.setLevel(_to_py_log_level(getattr(cfg, "log_level_global", logging.INFO))) + root.propagate = False + + fmt = logging.Formatter( + fmt="%(asctime)s %(levelname)s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + use_tqdm = bool(getattr(cfg, "log_tqdm", False)) + + if use_tqdm: + fmt = logging.Formatter(fmt="%(message)s") + # Console via tqdm.write for all varipeps loggers except optimizer + tw = TqdmWriteHandler() + tw.setFormatter(fmt) + tw.addFilter(ExcludeLoggerFilter("varipeps.optimizer")) + root.addHandler(tw) + + # Preserve file logging if enabled + if getattr(cfg, "log_to_file", False): + fh = logging.FileHandler(getattr(cfg, "log_file", "varipeps.log")) + fh.setFormatter(fmt) + root.addHandler(fh) + + # Optimizer uses a tqdm progress bar update handler + opt_logger = logging.getLogger("varipeps.optimizer") + for h in list(opt_logger.handlers): + opt_logger.removeHandler(h) + + from tqdm import tqdm + # Create a lightweight bar that we only update the postfix for + pbar = tqdm(total=0, position=0, leave=True, dynamic_ncols=True) + + if pbar is not None: + th = TqdmUpdateHandler(pbar) + th.setFormatter(fmt) + opt_logger.addHandler(th) + + # Keep propagation so optimizer still logs to file handler if present, + # while console is suppressed by the ExcludeLoggerFilter on root. + opt_logger.propagate = True + else: + # Standard console/file logging + if getattr(cfg, "log_to_console", True): + sh = logging.StreamHandler() + sh.setFormatter(fmt) + root.addHandler(sh) + + if getattr(cfg, "log_to_file", False): + fh = logging.FileHandler(getattr(cfg, "log_file", "varipeps.log")) + fh.setFormatter(fmt) + root.addHandler(fh) + + # Ensure optimizer has no leftover tqdm handler from a previous init + opt_logger = logging.getLogger("varipeps.optimizer") + for h in list(opt_logger.handlers): + opt_logger.removeHandler(h) + opt_logger.propagate = True + + # Per-module levels + logging.getLogger("varipeps.optimizer").setLevel( + _to_py_log_level(getattr(cfg, "log_level_optimizer", logging.INFO)) + ) + logging.getLogger("varipeps.ctmrg").setLevel( + _to_py_log_level(getattr(cfg, "log_level_ctmrg", logging.INFO)) + ) + logging.getLogger("varipeps.line_search").setLevel( + _to_py_log_level(getattr(cfg, "log_level_line_search", logging.INFO)) + ) + logging.getLogger("varipeps.expectation").setLevel( + _to_py_log_level(getattr(cfg, "log_level_expectation", logging.INFO)) + ) + + _LOGGING_INITIALIZED = True + +def ensure_logging_configured(cfg: Any | None = None) -> None: + """ + Initialize logging once on first call; subsequent calls are no-ops. + """ + global _LOGGING_INITIALIZED + if not _LOGGING_INITIALIZED: + init_logging(cfg)
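
Note (not part of the patch): a minimal usage sketch of the logging options added above. It assumes the global config instance is importable as `varipeps.varipeps_config` and that `LogLevel` and `init_logging` are exposed as in the hunks for `varipeps/config.py` and `varipeps/utils/logging_config.py`; the log file name is purely illustrative. Enum-valued options can be set from enum members, names, or plain integers via the new `__setattr__` branch.

```python
# Usage sketch for the new logging options (illustrative, not part of the diff).
import logging

from varipeps import varipeps_config
from varipeps.config import LogLevel
from varipeps.utils.logging_config import init_logging

# Enum-valued options accept enum members, enum names or plain integers,
# thanks to the coercion branch added to VariPEPS_Config.__setattr__.
varipeps_config.log_level_global = LogLevel.INFO
varipeps_config.log_level_ctmrg = "DEBUG"                 # resolved via LogLevel["DEBUG"]
varipeps_config.log_level_line_search = logging.WARNING   # plain int, resolved via LogLevel(30)

# Plain console logging plus a log file; tqdm mode stays off.
varipeps_config.log_to_console = True
varipeps_config.log_to_file = True
varipeps_config.log_file = "run.log"
varipeps_config.log_tqdm = False

# Re-apply the handlers after changing the options.  optimize_peps_network()
# calls ensure_logging_configured() on entry, so an explicit call is only
# needed when the configuration is changed after logging was already set up.
init_logging(varipeps_config)
```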
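
Note (not part of the patch): the CTMRG hunks guard the new per-step logging with a trace-time `logger.isEnabledFor(logging.DEBUG)` check and forward the traced value through `jax.debug.callback`. A self-contained sketch of that pattern, with a hypothetical function name:

```python
# Sketch of the logging pattern used in the CTMRG hunks above (hypothetical
# example function; jax.debug.callback and logging are the only APIs assumed).
import logging

import jax
import jax.numpy as jnp

logger = logging.getLogger("varipeps.ctmrg")


@jax.jit
def ctmrg_like_step(corner_svd, last_corner_svd):
    # Traced computation: convergence measure of one CTMRG step.
    measure = jnp.linalg.norm(corner_svd - last_corner_svd)

    # This check runs at trace time, so the callback is only staged into the
    # compiled computation when DEBUG logging is active for this module.
    if logger.isEnabledFor(logging.DEBUG):
        # jax.debug.callback feeds the concrete runtime value back to the
        # Python-side logger; ordered=True keeps messages in step order.
        jax.debug.callback(
            lambda msr: logger.debug("CTMRG: convergence measure %s", msr),
            measure,
            ordered=True,
        )

    return measure
```

Because the guard is evaluated when the function is traced, no host callback is inserted into the compiled loop when DEBUG logging is disabled, so the hot path stays free of host round-trips.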