Skip to content

Commit 062debf

Browse files
authored
[Integration] add swanlab logger (#10594)
* feat: add swanlabcallback * fix run * fix lint * fix requirements dev * fix url
1 parent c418bba commit 062debf

File tree

4 files changed

+119
-1
lines changed

4 files changed

+119
-1
lines changed

paddlenlp/trainer/integrations.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ def is_wandb_available():
4343
return importlib.util.find_spec("wandb") is not None
4444

4545

46+
def is_swanlab_available():
47+
return importlib.util.find_spec("swanlab") is not None
48+
49+
4650
def is_ray_available():
4751
return importlib.util.find_spec("ray.air") is not None
4852

@@ -55,6 +59,8 @@ def get_available_reporting_integrations():
5559
integrations.append("wandb")
5660
if is_tensorboardX_available():
5761
integrations.append("tensorboard")
62+
if is_swanlab_available():
63+
integrations.append("swanlab")
5864

5965
return integrations
6066

@@ -395,6 +401,85 @@ def on_save(self, args, state, control, **kwargs):
395401
self._wandb.log_artifact(artifact, aliases=[f"checkpoint-{state.global_step}"])
396402

397403

404+
class SwanLabCallback(TrainerCallback):
405+
"""
406+
A [`TrainerCallback`] that logs metrics, media to [Swanlab](https://swanlab.cn/).
407+
"""
408+
409+
def __init__(self):
410+
has_swanlab = is_swanlab_available()
411+
if not has_swanlab:
412+
raise RuntimeError("SwanlabCallback requires swanlab to be installed. Run `pip install swanlab`.")
413+
if has_swanlab:
414+
import swanlab
415+
416+
self._swanlab = swanlab
417+
418+
self._initialized = False
419+
420+
def setup(self, args, state, model, **kwargs):
421+
"""
422+
Setup the optional Swanlab integration.
423+
424+
One can subclass and override this method to customize the setup if needed.
425+
variables:
426+
Environment:
427+
- **SWANLAB_MODE** (`str`, *optional*, defaults to `"cloud"`):
428+
Whether to use swanlab cloud, local or disabled. Set `SWANLAB_MODE="local"` to use local. Set `SWANLAB_MODE="disabled"` to disable.
429+
- **SWANLAB_PROJECT** (`str`, *optional*, defaults to `"PaddleNLP"`):
430+
Set this to a custom string to store results in a different project.
431+
"""
432+
433+
if self._swanlab is None:
434+
return
435+
436+
self._initialized = True
437+
438+
if state.is_world_process_zero:
439+
logger.info('Automatic Swanlab logging enabled, to disable set os.environ["SWANLAB_MODE"] = "disabled"')
440+
441+
combined_dict = {**args.to_dict()}
442+
443+
if hasattr(model, "config") and model.config is not None:
444+
model_config = model.config.to_dict()
445+
combined_dict = {**model_config, **combined_dict}
446+
447+
trial_name = state.trial_name
448+
init_args = {}
449+
if trial_name is not None:
450+
init_args["name"] = trial_name
451+
init_args["group"] = args.run_name
452+
else:
453+
if not (args.run_name is None or args.run_name == args.output_dir):
454+
init_args["name"] = args.run_name
455+
init_args["dir"] = args.logging_dir
456+
if self._swanlab.get_run() is None:
457+
self._swanlab.init(
458+
project=os.getenv("SWANLAB_PROJECT", "PaddleNLP"),
459+
**init_args,
460+
)
461+
self._swanlab.config.update(combined_dict, allow_val_change=True)
462+
463+
def on_train_begin(self, args, state, control, model=None, **kwargs):
464+
if self._swanlab is None:
465+
return
466+
if not self._initialized:
467+
self.setup(args, state, model, **kwargs)
468+
469+
def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
470+
if self._swanlab is None:
471+
return
472+
473+
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
474+
if self._swanlab is None:
475+
return
476+
if not self._initialized:
477+
self.setup(args, state, model)
478+
if state.is_world_process_zero:
479+
logs = rewrite_logs(logs)
480+
self._swanlab.log({**logs, "train/global_step": state.global_step}, step=state.global_step)
481+
482+
398483
class AutoNLPCallback(TrainerCallback):
399484
"""
400485
A [`TrainerCallback`] that sends the logs to [`Ray Tune`] for [`AutoNLP`]
@@ -423,6 +508,7 @@ def on_evaluate(self, args, state, control, **kwargs):
423508
"autonlp": AutoNLPCallback,
424509
"wandb": WandbCallback,
425510
"tensorboard": TensorBoardCallback,
511+
"swanlab": SwanLabCallback,
426512
}
427513

428514

paddlenlp/trainer/training_args.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ class TrainingArguments:
376376
instance of `Dataset`.
377377
report_to (`str` or `List[str]`, *optional*, defaults to `"visualdl"`):
378378
The list of integrations to report the results and logs to.
379-
Supported platforms are `"visualdl"`/`"wandb"`/`"tensorboard"`.
379+
Supported platforms are `"visualdl"`/`"wandb"`/`"tensorboard"`/`"swanlab"`.
380380
`"none"` for no integrations.
381381
ddp_find_unused_parameters (`bool`, *optional*):
382382
When using distributed training, the value of the flag `find_unused_parameters` passed to

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ rouge
1818
tiktoken
1919
visualdl
2020
wandb
21+
swanlab
2122
tensorboard
2223
tensorboardX
2324
modelscope

tests/trainer/test_trainer_visualization.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from paddlenlp.trainer import TrainerControl, TrainerState, TrainingArguments
2424
from paddlenlp.trainer.integrations import (
25+
SwanLabCallback,
2526
TensorBoardCallback,
2627
VisualDLCallback,
2728
WandbCallback,
@@ -66,6 +67,36 @@ def test_wandbcallback(self):
6667
shutil.rmtree(output_dir)
6768

6869

70+
class TestSwanlabCallback(unittest.TestCase):
71+
def test_swanlabcallback(self):
72+
output_dir = tempfile.mkdtemp()
73+
args = TrainingArguments(
74+
output_dir=output_dir,
75+
max_steps=200,
76+
logging_steps=20,
77+
run_name="test_swanlabcallback",
78+
logging_dir=output_dir,
79+
)
80+
state = TrainerState(trial_name="PaddleNLP")
81+
control = TrainerControl()
82+
config = RegressionModelConfig(a=1, b=1)
83+
model = RegressionPretrainedModel(config)
84+
os.environ["SWANLAB_MODE"] = "disabled"
85+
swanlabcallback = SwanLabCallback()
86+
self.assertFalse(swanlabcallback._initialized)
87+
swanlabcallback.on_train_begin(args, state, control)
88+
self.assertTrue(swanlabcallback._initialized)
89+
for global_step in range(args.max_steps):
90+
state.global_step = global_step
91+
if global_step % args.logging_steps == 0:
92+
log = {"loss": 100 - 0.4 * global_step, "learning_rate": 0.1, "global_step": global_step}
93+
swanlabcallback.on_log(args, state, control, logs=log)
94+
swanlabcallback.on_train_end(args, state, control, model=model)
95+
swanlabcallback._swanlab.finish()
96+
os.environ.pop("SWANLAB_MODE", None)
97+
shutil.rmtree(output_dir)
98+
99+
69100
class TestTensorboardCallback(unittest.TestCase):
70101
def test_tbcallback(self):
71102
output_dir = tempfile.mkdtemp()

0 commit comments

Comments
 (0)