Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions varipeps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,4 @@

jax_config.update("jax_enable_x64", True)

from tqdm_loggable.tqdm_logging import tqdm_logging
import datetime

tqdm_logging.set_log_rate(datetime.timedelta(seconds=60))

del datetime
del tqdm_logging
del jax_config
47 changes: 45 additions & 2 deletions varipeps/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from dataclasses import dataclass
from enum import Enum, IntEnum, auto, unique
from typing import TypeVar, Tuple, Any, Type, NoReturn
import logging

import numpy as np

from jax.tree_util import register_pytree_node_class

from typing import TypeVar, Tuple, Any, Type, NoReturn

T_VariPEPS_Config = TypeVar("T_VariPEPS_Config", bound="VariPEPS_Config")

Expand Down Expand Up @@ -54,6 +55,15 @@ class Slurm_Restart_Mode(IntEnum):
AUTOMATIC_RESTART = auto() #: Write restart script and start new slurm job with it


@unique
class LogLevel(IntEnum):
OFF = 0
ERROR = logging.ERROR
WARNING = logging.WARNING
INFO = logging.INFO
DEBUG = logging.DEBUG


@dataclass
@register_pytree_node_class
class VariPEPS_Config:
Expand Down Expand Up @@ -310,6 +320,17 @@ class VariPEPS_Config:
# Slurm
slurm_restart_mode: Slurm_Restart_Mode = Slurm_Restart_Mode.WRITE_NEED_RESTART_FILE

# Logging configuration
log_level_global: LogLevel = LogLevel.INFO
log_level_optimizer: LogLevel = LogLevel.INFO
log_level_ctmrg: LogLevel = LogLevel.INFO
log_level_line_search: LogLevel = LogLevel.INFO
log_level_expectation: LogLevel = LogLevel.INFO
log_to_console: bool = True
log_to_file: bool = False
log_file: str = "varipeps.log"
log_step_summary_every_n: int = 1
Comment on lines +323 to +332
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add description of the config flag to the documentation.


def update(self, name: str, value: Any) -> NoReturn:
self.__setattr__(name, value)

Expand Down Expand Up @@ -346,12 +367,33 @@ def __setattr__(self, name: str, value: Any) -> NoReturn:
elif (
field.type is bool
and hasattr(value, "dtype")
and np.isdtype(value.dtype, np.bool)
and np.issubdtype(value.dtype, np.bool_)
and value.size == 1
):
if value.ndim > 0:
value = value.reshape(-1)[0]
value = bool(value)
elif isinstance(field.type, type) and issubclass(field.type, Enum):
# Accept ints/np.int64 or enum names for Enum fields
if isinstance(value, field.type):
pass
elif isinstance(value, (int,)) or (
hasattr(value, "dtype")
and np.issubdtype(value.dtype, np.integer)
and value.size == 1
):
if hasattr(value, "ndim") and value.ndim > 0:
value = value.reshape(-1)[0]
value = field.type(int(value))
elif isinstance(value, str):
try:
value = field.type[value]
except KeyError:
value = field.type(int(value))
else:
raise TypeError(
f"Type mismatch for option '{name}', got '{type(value)}', expected '{field.type}'."
)
else:
raise TypeError(
f"Type mismatch for option '{name}', got '{type(value)}', expected '{field.type}'."
Expand Down Expand Up @@ -395,6 +437,7 @@ class ConfigModuleWrapper:
"Projector_Method",
"Wavevector_Type",
"Slurm_Restart_Mode",
"LogLevel",
"VariPEPS_Config",
"config",
}
Expand Down
61 changes: 46 additions & 15 deletions varipeps/ctmrg/routine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
from jax import jit, custom_vjp, vjp, tree_util
from jax.lax import cond, while_loop
import jax.debug as jdebug
import logging
import time
import jax

logger = logging.getLogger("varipeps.ctmrg")

from varipeps import varipeps_config, varipeps_global_state
from varipeps.peps import PEPS_Tensor, PEPS_Tensor_Split_Transfer, PEPS_Unit_Cell
Expand Down Expand Up @@ -515,9 +520,8 @@ def corner_svd_func(old, new, old_corner, conv_eps, config):
eps,
config,
)

if config.ctmrg_print_steps:
debug_print("CTMRG: {}: {}", count, measure)
if logger.isEnabledFor(logging.DEBUG):
jax.debug.callback(lambda cnt, msr: logger.debug(f"CTMRG: Step {cnt}: {msr}"), count, measure, ordered=True)
if config.ctmrg_verbose_output:
jax.debug.callback(print_verbose, verbose_data, ordered=True)

Expand Down Expand Up @@ -620,9 +624,9 @@ def calc_ctmrg_env(
best_norm_smallest_S = None
best_truncation_eps = None
have_been_increased = False

while True:
tmp_count = 0
t0 = time.perf_counter()
corner_singular_vals = None

while tmp_count < varipeps_config.ctmrg_max_steps and (
Expand Down Expand Up @@ -720,6 +724,17 @@ def calc_ctmrg_env(
else:
converged = False
end_count = tmp_count

if not converged and logger.isEnabledFor(logging.WARNING):
logger.warning(
"CTMRG: ❌ did not converge, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)",
time.perf_counter() - t0, end_count, norm_smallest_S
)
elif logger.isEnabledFor(logging.INFO):
logger.info(
"CTMRG: ✅ converged, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)",
time.perf_counter() - t0, end_count, norm_smallest_S
)

if converged and (
working_unitcell[0, 0][0][0].chi > best_chi or best_result is None
Expand Down Expand Up @@ -751,9 +766,9 @@ def calc_ctmrg_env(
working_unitcell = working_unitcell.change_chi(new_chi)
initial_unitcell = initial_unitcell.change_chi(new_chi)

if varipeps_config.ctmrg_print_steps:
debug_print(
"CTMRG: Increasing chi to {} since smallest SVD Norm was {}.",
if logger.isEnabledFor(logging.WARNING):
logger.warning(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only a info level notification not a warning. This is intended behavior.

"Increasing chi to %d since smallest SVD Norm was %.3e.",
new_chi,
norm_smallest_S,
)
Expand Down Expand Up @@ -785,9 +800,9 @@ def calc_ctmrg_env(
if not new_chi in already_tried_chi:
working_unitcell = working_unitcell.change_chi(new_chi)

if varipeps_config.ctmrg_print_steps:
debug_print(
"CTMRG: Decreasing chi to {} since smallest SVD Norm was {} or routine did not converge.",
if logger.isEnabledFor(logging.WARNING):
logger.warning(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only a info level notification not a warning. This is intended behavior.

"Decreasing chi to %d since smallest SVD Norm was %.3e or routine did not converge.",
new_chi,
norm_smallest_S,
)
Expand All @@ -809,9 +824,9 @@ def calc_ctmrg_env(
new_truncation_eps
<= varipeps_config.ctmrg_increase_truncation_eps_max_value
):
if varipeps_config.ctmrg_print_steps:
debug_print(
"CTMRG: Increasing SVD truncation eps to {}.",
if logger.isEnabledFor(logging.WARNING):
logger.warning(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only a info level notification not a warning. This is intended behavior.

"Increasing SVD truncation eps to %.1e.",
new_truncation_eps,
)
varipeps_global_state.ctmrg_effective_truncation_eps = (
Expand Down Expand Up @@ -884,6 +899,8 @@ def calc_ctmrg_env_fwd(
Internal helper function of custom VJP to calculate the values in
the forward sweep.
"""
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Custom VJP: Starting forward CTMRG calculation.")
new_unitcell, last_truncation_eps, norm_smallest_S = calc_ctmrg_env_custom_rule(
peps_tensors, unitcell, _return_truncation_eps=True
)
Expand Down Expand Up @@ -937,8 +954,8 @@ def _ctmrg_rev_while_body(carry):

count += 1

if config.ad_custom_print_steps:
debug_print("Custom VJP: {}: {}", count, measure)
if logger.isEnabledFor(logging.DEBUG):
jax.debug.callback(lambda cnt, msr: logger.debug(f"Custom VJP: Step {cnt}: {msr}"), count, measure, ordered=True)
if config.ad_custom_verbose_output:
jax.debug.callback(print_verbose, verbose_data, ordered=True, ad=True)

Expand Down Expand Up @@ -1009,17 +1026,31 @@ def calc_ctmrg_env_rev(
Internal helper function of custom VJP to calculate the gradient in
the backward sweep.
"""
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Custom VJP: Starting reverse CTMRG calculation.")
unitcell_bar, _ = input_bar
peps_tensors, new_unitcell, input_unitcell, last_truncation_eps = res

varipeps_global_state.ctmrg_effective_truncation_eps = last_truncation_eps

if logger.isEnabledFor(logging.WARNING):
t0 = time.perf_counter()
t_bar, converged, end_count = _ctmrg_rev_workhorse(
peps_tensors, new_unitcell, unitcell_bar, varipeps_config, varipeps_global_state
)

varipeps_global_state.ctmrg_effective_truncation_eps = None

if not converged and logger.isEnabledFor(logging.WARNING):
logger.warning(
"Custom VJP: ❌ did not converge, took %.2f seconds. (Steps: %d)",
time.perf_counter() - t0, end_count
)
elif logger.isEnabledFor(logging.INFO):
logger.info(
"Custom VJP: ✅ converged, took %.2f seconds. (Steps: %d)",
time.perf_counter() - t0, end_count
)
if end_count == varipeps_config.ad_custom_max_steps and not converged:
raise CTMRGGradientNotConvergedError

Expand Down
40 changes: 27 additions & 13 deletions varipeps/ctmrg/structure_factor_routine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from jax import jit, custom_vjp, vjp, tree_util
from jax.lax import cond, while_loop
import jax.debug as jdebug
import logging
import time

logger = logging.getLogger("varipeps.ctmrg")

from varipeps import varipeps_config, varipeps_global_state
from varipeps.peps import PEPS_Tensor, PEPS_Unit_Cell
Expand Down Expand Up @@ -125,8 +129,8 @@ def _ctmrg_body_func_structure_factor(carry):
measure = jnp.linalg.norm(corner_svd - last_corner_svd)
converged = measure < eps

if config.ctmrg_print_steps:
debug_print("CTMRG: {}: {}", count, measure)
if logger.isEnabledFor(logging.DEBUG):
jax.debug.callback(lambda cnt, msr: logger.debug(f"CTMRG: Step {cnt}: {msr}"), count, measure, ordered=True)
if config.ctmrg_verbose_output:
for ti, ctm_enum_i, diff in verbose_data:
debug_print(
Expand Down Expand Up @@ -244,6 +248,7 @@ def calc_ctmrg_env_structure_factor(
norm_smallest_S = jnp.nan
already_tried_chi = {working_unitcell[0, 0][0][0].chi}

t0 = time.perf_counter()
while True:
tmp_count = 0
corner_singular_vals = None
Expand Down Expand Up @@ -304,6 +309,17 @@ def calc_ctmrg_env_structure_factor(
)
)

if not converged and logger.isEnabledFor(logging.WARNING):
logger.warning(
"CTMRG (SF): ❌ did not converge, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)",
time.perf_counter() - t0, end_count, norm_smallest_S
)
elif logger.isEnabledFor(logging.INFO):
logger.info(
"CTMRG (SF): ✅ converged, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)",
time.perf_counter() - t0, end_count, norm_smallest_S
)

current_truncation_eps = (
varipeps_config.ctmrg_truncation_eps
if varipeps_global_state.ctmrg_effective_truncation_eps is None
Expand All @@ -326,15 +342,14 @@ def calc_ctmrg_env_structure_factor(
working_unitcell = working_unitcell.change_chi(new_chi)
initial_unitcell = initial_unitcell.change_chi(new_chi)

if varipeps_config.ctmrg_print_steps:
debug_print(
"CTMRG: Increasing chi to {} since smallest SVD Norm was {}.",
if logger.isEnabledFor(logging.INFO):
logger.info(
"CTMRG (SF): Increasing chi to %d since smallest SVD Norm was %.3e.",
new_chi,
norm_smallest_S,
)

already_tried_chi.add(new_chi)

continue
elif (
varipeps_config.ctmrg_heuristic_decrease_chi
Expand All @@ -351,15 +366,14 @@ def calc_ctmrg_env_structure_factor(
if not new_chi in already_tried_chi:
working_unitcell = working_unitcell.change_chi(new_chi)

if varipeps_config.ctmrg_print_steps:
debug_print(
"CTMRG: Decreasing chi to {} since smallest SVD Norm was {}.",
if logger.isEnabledFor(logging.INFO):
logger.info(
"CTMRG (SF): Decreasing chi to %d since smallest SVD Norm was %.3e.",
new_chi,
norm_smallest_S,
)

already_tried_chi.add(new_chi)

continue

if (
Expand All @@ -375,9 +389,9 @@ def calc_ctmrg_env_structure_factor(
new_truncation_eps
<= varipeps_config.ctmrg_increase_truncation_eps_max_value
):
if varipeps_config.ctmrg_print_steps:
debug_print(
"CTMRG: Increasing SVD truncation eps to {}.",
if logger.isEnabledFor(logging.INFO):
logger.info(
"CTMRG (SF): Increasing SVD truncation eps to %g.",
new_truncation_eps,
)
varipeps_global_state.ctmrg_effective_truncation_eps = (
Expand Down
Loading