diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index b45d85d1f..beaa2bc63 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -30,6 +30,7 @@ import numpy as np import torch from pyannote.core import Segment, SlidingWindowFeature +from pytorch_lightning.loggers import TensorBoardLogger, MLFlowLogger from torch.utils.data._utils.collate import default_collate from torchmetrics import Metric from torchmetrics.classification import BinaryAUROC, MultilabelAUROC, MulticlassAUROC @@ -429,7 +430,7 @@ def validation_step(self, batch, batch_idx: int): ): return - # visualize first 9 validation samples of first batch in Tensorboard + # visualize first 9 validation samples of first batch in Tensorboard/MLflow X = X.cpu().numpy() y = y.float().cpu().numpy() y_pred = y_pred.cpu().numpy() @@ -478,8 +479,15 @@ def validation_step(self, batch, batch_idx: int): plt.tight_layout() - self.model.logger.experiment.add_figure( - f"{self.logging_prefix}ValSamples", fig, self.model.current_epoch - ) + if isinstance(self.model.logger, TensorBoardLogger): + self.model.logger.experiment.add_figure( + f"{self.logging_prefix}ValSamples", fig, self.model.current_epoch + ) + elif isinstance(self.model.logger, MLFlowLogger): + self.model.logger.experiment.log_figure( + run_id=self.model.logger.run_id, + figure=fig, + artifact_file=f"{self.logging_prefix}ValSamples_epoch{self.model.current_epoch}.png", + ) plt.close(fig)