From 2ecdc23c2129a132d5ef7432182a9619730f02ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Tue, 22 Jun 2021 14:41:57 +0200 Subject: [PATCH] chore: make tensorboard logs lighter (#687) --- pyannote/audio/tasks/segmentation/mixins.py | 8 ++++++-- pyannote/audio/tasks/segmentation/segmentation.py | 14 +++++++++----- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index 63358434f..0f4cc544e 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -495,8 +495,12 @@ def validation_step(self, batch, batch_idx: int): logger=True, ) - # log first batch visualization every 10 epochs. - if self.model.current_epoch % 10 > 0 or batch_idx > 0: + # log first batch visualization every 2^n epochs. + if ( + self.model.current_epoch == 0 + or math.log2(self.model.current_epoch) % 1 > 0 + or batch_idx > 0 + ): return # visualize first 9 validation samples of first batch in Tensorboard diff --git a/pyannote/audio/tasks/segmentation/segmentation.py b/pyannote/audio/tasks/segmentation/segmentation.py index fca6b62f0..7befb054d 100644 --- a/pyannote/audio/tasks/segmentation/segmentation.py +++ b/pyannote/audio/tasks/segmentation/segmentation.py @@ -337,7 +337,7 @@ def training_step(self, batch, batch_idx: int): self.model.log( f"{self.ACRONYM}@train_seg_loss", seg_loss, - on_step=True, + on_step=False, on_epoch=True, prog_bar=False, logger=True, @@ -354,7 +354,7 @@ def training_step(self, batch, batch_idx: int): self.model.log( f"{self.ACRONYM}@train_vad_loss", vad_loss, - on_step=True, + on_step=False, on_epoch=True, prog_bar=False, logger=True, @@ -365,7 +365,7 @@ def training_step(self, batch, batch_idx: int): self.model.log( f"{self.ACRONYM}@train_loss", loss, - on_step=True, + on_step=False, on_epoch=True, prog_bar=True, logger=True, @@ -414,8 +414,12 @@ def validation_step(self, batch, batch_idx: int): logger=True, ) - # log first batch visualization every 10 epochs. - if self.model.current_epoch % 10 > 0 or batch_idx > 0: + # log first batch visualization every 2^n epochs. + if ( + self.model.current_epoch == 0 + or math.log2(self.model.current_epoch) % 1 > 0 + or batch_idx > 0 + ): return # visualize first 9 validation samples of first batch in Tensorboard