Skip to content

Commit

Permalink
Solved minor bug with MLFlow logger (Lightning-AI#16418)
Browse files Browse the repository at this point in the history
  • Loading branch information
BrianPulfer committed Jan 20, 2023
1 parent d3de5c6 commit 6fd914f
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed an unintended limitation for calling `save_hyperparameters` on mixin classes that don't subclass `LightningModule`/`LightningDataModule` ([#16369](https://github.com/Lightning-AI/lightning/pull/16369))

- Fixed an issue with `MLFlowLogger` logging the wrong keys with `.log_hyperparams()` ([#16418](https://github.com/Lightning-AI/lightning/pull/16418))



## [1.9.0] - 2023-01-17

Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}", category=RuntimeWarning
)
continue
params_list.append(Param(key=v, value=v))
params_list.append(Param(key=k, value=v))

self.experiment.log_batch(run_id=self.run_id, params=params_list)

Expand Down
4 changes: 3 additions & 1 deletion tests/tests_pytorch/loggers/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,15 +253,17 @@ def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir):
logger.log_hyperparams(params)

logger.experiment.log_batch.assert_called_once_with(
run_id=logger.run_id, params=[param(key="test_param", value="test_param")]
run_id=logger.run_id, params=[param(key="test", value="test_param")]
)
param.assert_called_with(key="test", value="test_param")

metrics = {"some_metric": 10}
logger.log_metrics(metrics)

logger.experiment.log_batch.assert_called_with(
run_id=logger.run_id, metrics=[metric(key="some_metric", value=10, timestamp=1000, step=0)]
)
metric.assert_called_with(key="some_metric", value=10, timestamp=1000, step=0)

logger._mlflow_client.create_experiment.assert_called_once_with(
name="test", artifact_location="my_artifact_location"
Expand Down

0 comments on commit 6fd914f

Please sign in to comment.