From bdd79c44eb73609c43c390a2051bf621ea91bc10 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Fri, 15 Mar 2024 13:58:22 +0100 Subject: [PATCH] refactor: apply black formatting --- .../algos/pytorch/common/get_lr_scheduler.py | 4 +--- stable_learning_control/algos/tf2/sac/sac.py | 4 +--- stable_learning_control/utils/plot.py | 4 +++- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/stable_learning_control/algos/pytorch/common/get_lr_scheduler.py b/stable_learning_control/algos/pytorch/common/get_lr_scheduler.py index a2935d10..025f0189 100644 --- a/stable_learning_control/algos/pytorch/common/get_lr_scheduler.py +++ b/stable_learning_control/algos/pytorch/common/get_lr_scheduler.py @@ -149,9 +149,7 @@ def estimate_step_learning_rate( decay_rate = get_exponential_decay_rate( lr_start, lr_final, adjusted_total_steps ) - lr = float( - Decimal(lr_start) * (Decimal(decay_rate) ** Decimal(adjusted_step)) - ) + lr = float(Decimal(lr_start) * (Decimal(decay_rate) ** Decimal(adjusted_step))) else: supported_schedulers = ["LambdaLR", "ExponentialLR"] raise ValueError( diff --git a/stable_learning_control/algos/tf2/sac/sac.py b/stable_learning_control/algos/tf2/sac/sac.py index da2fcb3a..e5206d12 100644 --- a/stable_learning_control/algos/tf2/sac/sac.py +++ b/stable_learning_control/algos/tf2/sac/sac.py @@ -1023,9 +1023,7 @@ def sac( type="warning", ) decay_types[name] = lr_decay_type - lr_a_decay_type, lr_c_decay_type, lr_alpha_decay_type = ( - decay_types.values() - ) + lr_a_decay_type, lr_c_decay_type, lr_alpha_decay_type = decay_types.values() # Calculate the number of learning rate scheduler steps. if lr_decay_ref == "step": diff --git a/stable_learning_control/utils/plot.py b/stable_learning_control/utils/plot.py index 9770ad57..c03f1d57 100644 --- a/stable_learning_control/utils/plot.py +++ b/stable_learning_control/utils/plot.py @@ -76,7 +76,9 @@ def plot_data( if isinstance(data, list): data = pd.concat(data, ignore_index=True) sns.set(style=style, font_scale=font_scale) - sns.lineplot(data=data, x=xaxis, y=value, hue=condition, errorbar=errorbar, **kwargs) + sns.lineplot( + data=data, x=xaxis, y=value, hue=condition, errorbar=errorbar, **kwargs + ) plt.legend(loc="best").set_draggable(True) xscale = np.max(np.asarray(data[xaxis])) > 5e3