-
Notifications
You must be signed in to change notification settings - Fork 1
Replace tqdm with a logger for different modules. #3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
7da0031
2dcbf44
69cd860
6489f3c
27a3f12
413615c
a405978
929a975
912c3dc
18cbc3e
424c07d
7fb3f45
cae78c4
b399c59
9d60993
eda0237
473de9b
21e26e8
4bc445c
829cde1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,11 @@ | |
| from jax import jit, custom_vjp, vjp, tree_util | ||
| from jax.lax import cond, while_loop | ||
| import jax.debug as jdebug | ||
| import logging | ||
| import time | ||
| import jax | ||
|
|
||
| logger = logging.getLogger("varipeps.ctmrg") | ||
|
|
||
| from varipeps import varipeps_config, varipeps_global_state | ||
| from varipeps.peps import PEPS_Tensor, PEPS_Tensor_Split_Transfer, PEPS_Unit_Cell | ||
|
|
@@ -515,9 +520,8 @@ def corner_svd_func(old, new, old_corner, conv_eps, config): | |
| eps, | ||
| config, | ||
| ) | ||
|
|
||
| if config.ctmrg_print_steps: | ||
| debug_print("CTMRG: {}: {}", count, measure) | ||
| if logger.isEnabledFor(logging.DEBUG): | ||
| jax.debug.callback(lambda cnt, msr: logger.debug(f"CTMRG: Step {cnt}: {msr}"), count, measure, ordered=True) | ||
| if config.ctmrg_verbose_output: | ||
| jax.debug.callback(print_verbose, verbose_data, ordered=True) | ||
|
|
||
|
|
@@ -620,9 +624,9 @@ def calc_ctmrg_env( | |
| best_norm_smallest_S = None | ||
| best_truncation_eps = None | ||
| have_been_increased = False | ||
|
|
||
| while True: | ||
| tmp_count = 0 | ||
| t0 = time.perf_counter() | ||
| corner_singular_vals = None | ||
|
|
||
| while tmp_count < varipeps_config.ctmrg_max_steps and ( | ||
|
|
@@ -720,6 +724,17 @@ def calc_ctmrg_env( | |
| else: | ||
| converged = False | ||
| end_count = tmp_count | ||
|
|
||
| if not converged and logger.isEnabledFor(logging.WARNING): | ||
| logger.warning( | ||
| "CTMRG: ❌ did not converge, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)", | ||
| time.perf_counter() - t0, end_count, norm_smallest_S | ||
| ) | ||
| elif logger.isEnabledFor(logging.INFO): | ||
| logger.info( | ||
| "CTMRG: ✅ converged, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)", | ||
| time.perf_counter() - t0, end_count, norm_smallest_S | ||
| ) | ||
|
|
||
| if converged and ( | ||
| working_unitcell[0, 0][0][0].chi > best_chi or best_result is None | ||
|
|
@@ -751,9 +766,9 @@ def calc_ctmrg_env( | |
| working_unitcell = working_unitcell.change_chi(new_chi) | ||
| initial_unitcell = initial_unitcell.change_chi(new_chi) | ||
|
|
||
| if varipeps_config.ctmrg_print_steps: | ||
| debug_print( | ||
| "CTMRG: Increasing chi to {} since smallest SVD Norm was {}.", | ||
| if logger.isEnabledFor(logging.WARNING): | ||
| logger.warning( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is only a |
||
| "Increasing chi to %d since smallest SVD Norm was %.3e.", | ||
| new_chi, | ||
| norm_smallest_S, | ||
| ) | ||
|
|
@@ -785,9 +800,9 @@ def calc_ctmrg_env( | |
| if not new_chi in already_tried_chi: | ||
| working_unitcell = working_unitcell.change_chi(new_chi) | ||
|
|
||
| if varipeps_config.ctmrg_print_steps: | ||
| debug_print( | ||
| "CTMRG: Decreasing chi to {} since smallest SVD Norm was {} or routine did not converge.", | ||
| if logger.isEnabledFor(logging.WARNING): | ||
| logger.warning( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is only a |
||
| "Decreasing chi to %d since smallest SVD Norm was %.3e or routine did not converge.", | ||
| new_chi, | ||
| norm_smallest_S, | ||
| ) | ||
|
|
@@ -809,9 +824,9 @@ def calc_ctmrg_env( | |
| new_truncation_eps | ||
| <= varipeps_config.ctmrg_increase_truncation_eps_max_value | ||
| ): | ||
| if varipeps_config.ctmrg_print_steps: | ||
| debug_print( | ||
| "CTMRG: Increasing SVD truncation eps to {}.", | ||
| if logger.isEnabledFor(logging.WARNING): | ||
| logger.warning( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is only a |
||
| "Increasing SVD truncation eps to %.1e.", | ||
| new_truncation_eps, | ||
| ) | ||
| varipeps_global_state.ctmrg_effective_truncation_eps = ( | ||
|
|
@@ -884,6 +899,8 @@ def calc_ctmrg_env_fwd( | |
| Internal helper function of custom VJP to calculate the values in | ||
| the forward sweep. | ||
| """ | ||
| if logger.isEnabledFor(logging.DEBUG): | ||
| logger.debug("Custom VJP: Starting forward CTMRG calculation.") | ||
| new_unitcell, last_truncation_eps, norm_smallest_S = calc_ctmrg_env_custom_rule( | ||
| peps_tensors, unitcell, _return_truncation_eps=True | ||
| ) | ||
|
|
@@ -937,8 +954,8 @@ def _ctmrg_rev_while_body(carry): | |
|
|
||
| count += 1 | ||
|
|
||
| if config.ad_custom_print_steps: | ||
| debug_print("Custom VJP: {}: {}", count, measure) | ||
| if logger.isEnabledFor(logging.DEBUG): | ||
| jax.debug.callback(lambda cnt, msr: logger.debug(f"Custom VJP: Step {cnt}: {msr}"), count, measure, ordered=True) | ||
| if config.ad_custom_verbose_output: | ||
| jax.debug.callback(print_verbose, verbose_data, ordered=True, ad=True) | ||
|
|
||
|
|
@@ -1009,17 +1026,31 @@ def calc_ctmrg_env_rev( | |
| Internal helper function of custom VJP to calculate the gradient in | ||
| the backward sweep. | ||
| """ | ||
| if logger.isEnabledFor(logging.DEBUG): | ||
| logger.debug("Custom VJP: Starting reverse CTMRG calculation.") | ||
| unitcell_bar, _ = input_bar | ||
| peps_tensors, new_unitcell, input_unitcell, last_truncation_eps = res | ||
|
|
||
| varipeps_global_state.ctmrg_effective_truncation_eps = last_truncation_eps | ||
|
|
||
| if logger.isEnabledFor(logging.WARNING): | ||
| t0 = time.perf_counter() | ||
| t_bar, converged, end_count = _ctmrg_rev_workhorse( | ||
| peps_tensors, new_unitcell, unitcell_bar, varipeps_config, varipeps_global_state | ||
| ) | ||
|
|
||
| varipeps_global_state.ctmrg_effective_truncation_eps = None | ||
|
|
||
| if not converged and logger.isEnabledFor(logging.WARNING): | ||
| logger.warning( | ||
| "Custom VJP: ❌ did not converge, took %.2f seconds. (Steps: %d)", | ||
| time.perf_counter() - t0, end_count | ||
| ) | ||
| elif logger.isEnabledFor(logging.INFO): | ||
| logger.info( | ||
| "Custom VJP: ✅ converged, took %.2f seconds. (Steps: %d)", | ||
| time.perf_counter() - t0, end_count | ||
| ) | ||
| if end_count == varipeps_config.ad_custom_max_steps and not converged: | ||
| raise CTMRGGradientNotConvergedError | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add description of the config flag to the documentation.