diff --git a/minerva/tasks.py b/minerva/tasks.py index 6e94f78..58d8716 100644 --- a/minerva/tasks.py +++ b/minerva/tasks.py @@ -52,6 +52,10 @@ def train(algo_name, augmentation = DrQPipeline(augmentations, n_mean=n_mean) params['augmentation'] = augmentation + # add action scaler if continuous action-space + if not dataset.is_action_discrete(): + params['action_scaler'] = 'min_max' + # train algo = create_algo(algo_name, dataset.is_action_discrete(), **params) algo.fit(train_data,