fix(wandb): fix wandb config format and run name (#317)
This commit ensures that the wandb config format is human-readable and fixes the 'run_name' not being applied correctly.
rickstaa committed Aug 8, 2023
1 parent 44a35af commit ca048de
Showing 7 changed files with 136 additions and 66 deletions.
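
In short, each algorithm now derives the Weights & Biases run name from the last component of the output directory whenever wandb is enabled and no explicit name was passed (the removed CLI-only fallback read a misspelled 'ouput_dir' key). A minimal sketch of the added fallback, assuming an illustrative output directory:

    from pathlib import PurePath

    # Illustrative logger kwargs; only 'use_wandb' and 'output_dir' matter here.
    logger_kwargs = {"use_wandb": True, "output_dir": "data/lac/runs/run_1691486400"}

    # Fallback added in this commit: name the run after the output directory.
    if logger_kwargs.get("use_wandb") and not logger_kwargs.get("wandb_run_name"):
        logger_kwargs["wandb_run_name"] = PurePath(logger_kwargs["output_dir"]).parts[-1]

    print(logger_kwargs["wandb_run_name"])  # -> run_1691486400
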
9 changes: 6 additions & 3 deletions stable_learning_control/algos/pytorch/lac/lac.py
@@ -948,7 +948,7 @@ def lac(
     # Setup algorithm parameters.
     total_steps = steps_per_epoch * epochs
     env = env_fn()
-    hyper_param_dict["env"] = get_env_id(env)  # Store env id in hyperparameter dict.
+    hyper_param_dict["env"] = env  # Add env to hyperparameters.
 
     # Validate gymnasium env.
     # NOTE: The current implementation only works with continuous spaces.
@@ -998,6 +998,11 @@ def lac(
         else False
     )
     use_wandb = logger_kwargs.get("use_wandb")
+    if use_wandb and not logger_kwargs.get("wandb_run_name"):
+        # Create wandb_run_name if wandb is used and no name is provided.
+        logger_kwargs["wandb_run_name"] = PurePath(logger_kwargs["output_dir"]).parts[
+            -1
+        ]
     logger = EpochLogger(**logger_kwargs)
 
     # Retrieve max episode length.
@@ -1687,8 +1692,6 @@ def lac(
             f"../../../../../data/lac/{args.env.lower()}/runs/run_{int(time.time())}",
         )
     )
-    if args.use_wandb:  # Add the wandb run name to the logger kwargs.
-        logger_kwargs["wandb_run_name"] = PurePath(logger_kwargs["ouput_dir"]).parts[-1]
     torch.set_num_threads(torch.get_num_threads())
 
     lac(

9 changes: 6 additions & 3 deletions stable_learning_control/algos/pytorch/sac/sac.py
@@ -893,7 +893,7 @@ def sac(
     # Setup algorithm parameters.
     total_steps = steps_per_epoch * epochs
     env = env_fn()
-    hyper_param_dict["env"] = get_env_id(env)  # Store env id in hyperparameter dict.
+    hyper_param_dict["env"] = env  # Add env to hyperparameters.
 
     # Validate gymnasium env.
     # NOTE: The current implementation only works with continuous spaces.
@@ -943,6 +943,11 @@ def sac(
         else False
     )
     use_wandb = logger_kwargs.get("use_wandb")
+    if use_wandb and not logger_kwargs.get("wandb_run_name"):
+        # Create wandb_run_name if wandb is used and no name is provided.
+        logger_kwargs["wandb_run_name"] = PurePath(logger_kwargs["output_dir"]).parts[
+            -1
+        ]
     logger = EpochLogger(**logger_kwargs)
 
     # Retrieve max episode length.
@@ -1593,8 +1598,6 @@ def sac(
             f"../../../../../data/sac/{args.env.lower()}/runs/run_{int(time.time())}",
         )
     )
-    if args.use_wandb:  # Add the wandb run name to the logger kwargs.
-        logger_kwargs["wandb_run_name"] = PurePath(logger_kwargs["ouput_dir"]).parts[-1]
     torch.set_num_threads(torch.get_num_threads())
 
     sac(

9 changes: 6 additions & 3 deletions stable_learning_control/algos/tf2/lac/lac.py
@@ -898,7 +898,7 @@ def lac(
     # Setup algorithm parameters.
     total_steps = steps_per_epoch * epochs
     env = env_fn()
-    hyper_param_dict["env"] = get_env_id(env)  # Store env id in hyperparameter dict.
+    hyper_param_dict["env"] = env  # Add env to hyperparameters.
 
     # Validate gymnasium env.
     # NOTE: The current implementation only works with continuous spaces.
@@ -948,6 +948,11 @@ def lac(
         if "use_tensorboard" in logger_kwargs.keys()
         else False
     )
+    if logger_kwargs.get("use_wandb") and not logger_kwargs.get("wandb_run_name"):
+        # Create wandb_run_name if wandb is used and no name is provided.
+        logger_kwargs["wandb_run_name"] = PurePath(logger_kwargs["output_dir"]).parts[
+            -1
+        ]
     logger = EpochLogger(**logger_kwargs)
 
     # Retrieve max episode length.
@@ -1627,8 +1632,6 @@ def lac(
             f"../../../../../data/lac/{args.env.lower()}/runs/run_{int(time.time())}",
         )
     )
-    if args.use_wandb:  # Add the wandb run name to the logger kwargs.
-        logger_kwargs["wandb_run_name"] = PurePath(logger_kwargs["ouput_dir"]).parts[-1]
 
     lac(
         lambda: gym.make(args.env),

9 changes: 6 additions & 3 deletions stable_learning_control/algos/tf2/sac/sac.py
@@ -846,7 +846,7 @@ def sac(
     # Setup algorithm parameters.
     total_steps = steps_per_epoch * epochs
     env = env_fn()
-    hyper_param_dict["env"] = get_env_id(env)  # Store env id in hyperparameter dict.
+    hyper_param_dict["env"] = env  # Add env to hyperparameters.
 
     # Validate gymnasium env.
     # NOTE: The current implementation only works with continuous spaces.
@@ -896,6 +896,11 @@ def sac(
         if "use_tensorboard" in logger_kwargs.keys()
         else False
    )
+    if logger_kwargs.get("use_wandb") and not logger_kwargs.get("wandb_run_name"):
+        # Create wandb_run_name if wandb is used and no name is provided.
+        logger_kwargs["wandb_run_name"] = PurePath(logger_kwargs["output_dir"]).parts[
+            -1
+        ]
     logger = EpochLogger(**logger_kwargs)
 
     # Retrieve max episode length.
@@ -1545,8 +1550,6 @@ def sac(
             f"../../../../../data/sac/{args.env.lower()}/runs/run_{int(time.time())}",
         )
     )
-    if args.use_wandb:  # Add the wandb run name to the logger kwargs.
-        logger_kwargs["wandb_run_name"] = PurePath(logger_kwargs["ouput_dir"]).parts[-1]
 
     sac(
         lambda: gym.make(args.env),

105 changes: 100 additions & 5 deletions stable_learning_control/common/helpers.py
@@ -4,6 +4,7 @@
 import string
 from collections.abc import Iterable, MutableMapping
 
+import gymnasium as gym
 import numpy as np
 import torch
 
@@ -220,11 +221,51 @@ def get_env_id(env):
     Returns:
         str: The environment id.
     """
-    return (
-        env.unwrapped.spec.id
-        if hasattr(env.unwrapped.spec, "id")
-        else type(env.unwrapped).__name__
-    )
+    if isinstance(env, gym.Env):
+        return (
+            env.unwrapped.spec.id
+            if hasattr(env.unwrapped.spec, "id")
+            else type(env.unwrapped).__name__
+        )
+    return env
+
+
+def get_env_class(env):
+    """Get the environment class.
+
+    Args:
+        env (:obj:`gym.Env`): The environment.
+
+    Returns:
+        str: The environment class.
+    """
+    if isinstance(env, gym.Env):
+        return "{}.{}".format(
+            env.unwrapped.__module__, env.unwrapped.__class__.__name__
+        )
+    return env
+
+
+def parse_config_env_key(config):
+    """Replace environment objects (i.e. gym.Env) with their id and class path if they
+    are present in the config. Also removes the 'env_fn' from the config.
+
+    Args:
+        config (dict): The configuration dictionary.
+
+    Returns:
+        dict: The parsed configuration dictionary.
+    """
+    parsed_config = {}
+    for key, val in config.items():
+        if key == "env" and isinstance(val, gym.Env):
+            parsed_config[key] = get_env_id(val)
+            parsed_config["env_class"] = get_env_class(val)
+        elif key == "env_fn":  # Remove env_fn from config.
+            continue
+        else:
+            parsed_config[key] = val
+    return parsed_config
 
 
 def convert_to_snake_case(input_str):
@@ -278,3 +319,57 @@ def flatten_dict(d, parent_key="", sep="."):
         else:
             items.append((new_key, v))
     return dict(items)
+
+
+def convert_to_wandb_config(config):
+    """Transform the config to a format that looks better on Weights & Biases.
+
+    Args:
+        config (dict): The config that should be transformed.
+
+    Returns:
+        dict: The transformed config.
+    """
+    wandb_config = {}
+    for key, value in config.items():
+        if (
+            key
+            in [
+                "env_fn",
+                "output_dir",
+                "use_wandb",
+                "wandb_job_type",
+                "wandb_project",
+                "wandb_group",
+                "wandb_run_name",
+            ]
+            or value is None
+        ):  # Filter keys.
+            continue
+        elif key in ["policy", "disturber"]:  # Transform policy object to policy id.
+            value = "{}.{}".format(value.__module__, value.__class__.__name__)
+        elif key == "env" and isinstance(value, gym.Env):
+            wandb_config["env_class"] = get_env_class(value)
+            value = get_env_id(value)
+        wandb_config[key] = value
+    return wandb_config
+
+
+def convert_to_tb_config(config):
+    """Transform the config to a format that looks better on TensorBoard.
+
+    Args:
+        config (dict): The config that should be transformed.
+
+    Returns:
+        dict: The transformed config.
+    """
+    tb_config = {}
+    for key, value in config.items():
+        if key in ["env_fn"]:  # Skip env_fn.
+            continue
+        elif key == "env" and isinstance(value, gym.Env):
+            tb_config["env_class"] = get_env_class(value)
+            value = get_env_id(value)
+        tb_config[key] = value
+    return flatten_dict(tb_config)
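
Taken together, the new helpers make a config that still holds a live gym.Env render readably: parse_config_env_key swaps the env object for its id plus class path and drops 'env_fn', while convert_to_wandb_config additionally filters logger-specific keys. A rough usage sketch — the Pendulum-v1 environment and the printed class path are just examples:

    import gymnasium as gym

    from stable_learning_control.common.helpers import (
        convert_to_wandb_config,
        parse_config_env_key,
    )

    env = gym.make("Pendulum-v1")  # example environment
    config = {"env": env, "env_fn": lambda: env, "gamma": 0.99, "output_dir": "data/run_1"}

    # 'env' becomes its id, 'env_class' is added, and 'env_fn' is dropped.
    print(parse_config_env_key(config))
    # {'env': 'Pendulum-v1', 'env_class': 'gymnasium.envs.classic_control.pendulum.PendulumEnv',
    #  'gamma': 0.99, 'output_dir': 'data/run_1'}

    # The wandb variant also filters keys like 'output_dir' that only matter locally.
    print(convert_to_wandb_config(config))
    # {'env_class': 'gymnasium.envs.classic_control.pendulum.PendulumEnv', 'env': 'Pendulum-v1',
    #  'gamma': 0.99}
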
35 changes: 2 additions & 33 deletions stable_learning_control/utils/eval_robustness.py
@@ -3,13 +3,13 @@
 more information.
 """
 import ast
+import copy
 import importlib
 import inspect
 import os
 import sys
 from pathlib import Path, PurePath
 from textwrap import dedent
-import copy
 
 import gymnasium as gym
 import matplotlib.pyplot as plt
@@ -19,6 +19,7 @@
 
 from stable_learning_control.common.helpers import (
     convert_to_snake_case,
+    convert_to_wandb_config,
     friendly_err,
     get_env_id,
 )
@@ -27,38 +28,6 @@
 from stable_learning_control.utils.test_policy import load_policy_and_env
 
 
-def convert_to_wandb_config(config):
-    """Transform the config to a format that looks better on Weights & Biases.
-
-    Args:
-        config (dict): The config that should be transformed.
-
-    Returns:
-        dict: The transformed config.
-    """
-    wandb_config = {}
-    for key, value in config.items():
-        if (
-            key
-            in [
-                "output_dir",
-                "use_wandb",
-                "wandb_job_type",
-                "wandb_project",
-                "wandb_group",
-                "wandb_run_name",
-            ]
-            or value is None
-        ):  # Filter keys.
-            continue
-        elif key == "env":  # Transform env object to env id.
-            value = get_env_id(value)
-        elif key in ["policy", "disturber"]:  # Transform policy object to policy id.
-            value = "{}.{}".format(value.__module__, value.__class__.__name__)
-        wandb_config[key] = value
-    return wandb_config
-
-
 def get_human_readable_disturber_label(disturber_label):
     """Get a human readable label for a given disturber label.

26 changes: 10 additions & 16 deletions stable_learning_control/utils/log_utils/logx.py
@@ -11,6 +11,7 @@
     - Logs to Weights & Biases (besides logging to a file).
 """  # noqa
 import atexit
+import copy
 import glob
 import json
 import os
@@ -19,14 +20,18 @@
 import re
 import time
 from pathlib import Path
-import copy
 
 import joblib
 import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
-from stable_learning_control.common.helpers import is_scalar, flatten_dict
+from stable_learning_control.common.helpers import (
+    convert_to_tb_config,
+    convert_to_wandb_config,
+    is_scalar,
+    parse_config_env_key,
+)
 from stable_learning_control.user_config import (
     DEFAULT_STD_OUT_TYPE,
     PRINT_CONFIG,
@@ -529,6 +534,7 @@ def save_config(self, config):
             config (object): Configuration Python object you want to save.
         """
         if proc_id() == 0:
+            config = parse_config_env_key(config)
             self._config = config
             config_json = convert_json(config)
             if self.exp_name is not None:
@@ -1027,13 +1033,7 @@ def _global_step(self):
     def _wandb_config(self):
        """Transform the config to a format that looks better on Weights & Biases."""
         if self.wandb and self._config:
-            wandb_config = {}
-            for key, value in self._config.items():
-                if key in ["env_fn"]:  # Skip env_fn.
-                    continue
-                else:
-                    wandb_config[key] = value
-            return wandb_config
+            return convert_to_wandb_config(self._config)
         return None
 
     def watch_model_in_wandb(self, model):
@@ -1098,13 +1098,7 @@ def _log_wandb_artifacts(self):
     def _tb_config(self):
         """Modify the config to a format that looks better on Tensorboard."""
         if self.use_tensorboard and self._config:
-            tb_config = {}
-            for key, value in self._config.items():
-                if key in ["env_fn"]:  # Skip env_fn.
-                    continue
-                else:
-                    tb_config[key] = value
-            return flatten_dict(tb_config)
+            return convert_to_tb_config(self._config)
         return None
 
     @property
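
With these changes the logger itself stays thin: save_config normalizes the env entry once via parse_config_env_key, and the wandb/TensorBoard properties delegate to the shared converters. For TensorBoard the config is additionally flattened, since its hparams view cannot display nested dictionaries — a small sketch, where the nested 'opt_kwargs' entry is a made-up example:

    from stable_learning_control.common.helpers import convert_to_tb_config

    config = {"env": "Pendulum-v1", "opt_kwargs": {"lr": 1e-3}}  # 'opt_kwargs' is hypothetical
    # Nested keys are flattened with '.' separators via flatten_dict.
    print(convert_to_tb_config(config))  # -> {'env': 'Pendulum-v1', 'opt_kwargs.lr': 0.001}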
