diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 44ffed9cb..52eacb250 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -146,12 +146,19 @@ def main(): early_stopping = EarlyStopping("val_loss", patience=args.early_stopping_patience) csv_logger = CSVLogger(args.log_dir, name="", version="") - _logger=[csv_logger] + _logger = [csv_logger] if args.wandb_use: if args.wandb_resume and args.wandb_id is not None: - wandb_logger=WandbLogger(project=args.wandb_project, save_dir=args.log_dir, resume='must', id=args.wandb_id) + wandb_logger = WandbLogger( + project=args.wandb_project, + save_dir=args.log_dir, + resume="must", + id=args.wandb_id, + ) else: - wandb_logger=WandbLogger(project=args.wandb_project,name=args.wandb_name, save_dir=args.log_dir) + wandb_logger = WandbLogger( + project=args.wandb_project, name=args.wandb_name, save_dir=args.log_dir + ) _logger.append(wandb_logger) if args.tensorboard_use: