@@ -43,6 +43,10 @@ def is_wandb_available():
43
43
return importlib .util .find_spec ("wandb" ) is not None
44
44
45
45
46
def is_swanlab_available():
    """Report whether the optional `swanlab` package can be located for import."""
    swanlab_spec = importlib.util.find_spec("swanlab")
    if swanlab_spec is None:
        return False
    return True
48
+
49
+
46
50
def is_ray_available():
    """Report whether the `ray.air` submodule can be located for import.

    Returns:
        bool: True when `ray.air` is found, False otherwise.
    """
    try:
        return importlib.util.find_spec("ray.air") is not None
    except ModuleNotFoundError:
        # find_spec imports the parent package of a dotted name, so a missing
        # `ray` raises instead of yielding None; treat that as "not available".
        return False
48
52
@@ -55,6 +59,8 @@ def get_available_reporting_integrations():
55
59
integrations .append ("wandb" )
56
60
if is_tensorboardX_available ():
57
61
integrations .append ("tensorboard" )
62
+ if is_swanlab_available ():
63
+ integrations .append ("swanlab" )
58
64
59
65
return integrations
60
66
@@ -395,6 +401,85 @@ def on_save(self, args, state, control, **kwargs):
395
401
self ._wandb .log_artifact (artifact , aliases = [f"checkpoint-{ state .global_step } " ])
396
402
397
403
404
class SwanLabCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that logs metrics, media to [Swanlab](https://swanlab.cn/).
    """

    def __init__(self):
        # Fail fast: the callback cannot do anything without the optional dependency.
        if not is_swanlab_available():
            raise RuntimeError("SwanlabCallback requires swanlab to be installed. Run `pip install swanlab`.")
        # Imported lazily so that this module stays importable without swanlab installed.
        import swanlab

        self._swanlab = swanlab
        self._initialized = False

    def setup(self, args, state, model, **kwargs):
        """
        Setup the optional Swanlab integration.

        One can subclass and override this method to customize the setup if needed.

        Environment variables:

        - **SWANLAB_MODE** (`str`, *optional*, defaults to `"cloud"`):
            Whether to use swanlab cloud, local or disabled. Set `SWANLAB_MODE="local"` to use local. Set
            `SWANLAB_MODE="disabled"` to disable. This variable is not read by this callback; it is left to the
            swanlab library.
        - **SWANLAB_PROJECT** (`str`, *optional*, defaults to `"PaddleNLP"`):
            Set this to a custom string to store results in a different project.

        Args:
            args: Training arguments; `to_dict()`, `run_name`, `output_dir` and `logging_dir` are read.
            state: Trainer state; `is_world_process_zero` and `trial_name` are read.
            model: The model being trained; its `config.to_dict()` is recorded when a config is present.
            **kwargs: Accepted for signature compatibility and ignored.
        """
        if self._swanlab is None:
            return

        # Marked on every process so that non-zero ranks do not retry the setup on each event.
        self._initialized = True

        if state.is_world_process_zero:
            logger.info('Automatic Swanlab logging enabled, to disable set os.environ["SWANLAB_MODE"] = "disabled"')

            combined_dict = {**args.to_dict()}

            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                # Training arguments take precedence over model config entries with the same key.
                combined_dict = {**model_config, **combined_dict}

            trial_name = state.trial_name
            init_args = {}
            if trial_name is not None:
                init_args["name"] = trial_name
                init_args["group"] = args.run_name
            elif not (args.run_name is None or args.run_name == args.output_dir):
                # Only forward a run name the user actually chose (i.e. not the output_dir fallback).
                init_args["name"] = args.run_name
            init_args["dir"] = args.logging_dir

            # NOTE(review): confirm that `swanlab.init` accepts the `name`, `group` and `dir` keywords and that
            # `swanlab.config.update` accepts `allow_val_change` in the supported swanlab versions.
            if self._swanlab.get_run() is None:
                # Reuse an already active run instead of starting a second one.
                self._swanlab.init(
                    project=os.getenv("SWANLAB_PROJECT", "PaddleNLP"),
                    **init_args,
                )
            self._swanlab.config.update(combined_dict, allow_val_change=True)

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Run `setup` once at the start of training if it has not run yet."""
        if self._swanlab is None:
            return
        if not self._initialized:
            self.setup(args, state, model, **kwargs)

    def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
        """Hook called at the end of training; currently performs no action."""
        if self._swanlab is None:
            return

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        """
        Forward the trainer logs to Swanlab from the world-zero process.

        The logs are passed through `rewrite_logs` and sent together with a `train/global_step` entry, using
        `state.global_step` as the step. Nothing is sent when `logs` is None.
        """
        if self._swanlab is None:
            return
        if not self._initialized:
            self.setup(args, state, model)
        if logs is None:
            # `logs` defaults to None; there is nothing to rewrite or report in that case.
            return
        if state.is_world_process_zero:
            logs = rewrite_logs(logs)
            self._swanlab.log({**logs, "train/global_step": state.global_step}, step=state.global_step)
481
+
482
+
398
483
class AutoNLPCallback (TrainerCallback ):
399
484
"""
400
485
A [`TrainerCallback`] that sends the logs to [`Ray Tune`] for [`AutoNLP`]
@@ -423,6 +508,7 @@ def on_evaluate(self, args, state, control, **kwargs):
423
508
"autonlp" : AutoNLPCallback ,
424
509
"wandb" : WandbCallback ,
425
510
"tensorboard" : TensorBoardCallback ,
511
+ "swanlab" : SwanLabCallback ,
426
512
}
427
513
428
514
0 commit comments