# loggers.py — forked from SkafteNicki/pl_crossvalidate.
from pytorch_lightning import loggers as pl_loggers
class KFoldLogger:
    """Mixin adding per-fold version bookkeeping to a Lightning logger.

    The concrete loggers in this module list a ``pytorch_lightning`` logger
    first and this mixin second in their bases, and call :meth:`setup` at
    the end of their ``__init__``. The mixin keeps a fold counter in
    ``_fold_idx`` and mirrors it into ``_version`` as ``"fold<k>"``.
    """

    def setup(self):
        """Additional setup code to inject during ``__init__``.

        Resets the fold counter to 0 and sets ``_version`` to ``"fold0"``.
        This overwrites any ``_version`` the base logger set in its own
        ``__init__``.
        """
        self._fold_idx = 0
        self._version = "fold{}".format(self._fold_idx)

    def increment(self):
        """Advance bookkeeping to the next fold; run after a fold completes.

        Clears the cached ``_experiment`` handle, bumps ``_fold_idx`` by one
        and refreshes ``_version`` to match.
        """
        # Drop the cached experiment object. Presumably the Lightning base
        # logger lazily re-creates it on next access, under the new version
        # (TODO confirm per logger).
        self._experiment = None
        self._fold_idx = self._fold_idx + 1
        self._version = "fold{}".format(self._fold_idx)
class CometLogger(pl_loggers.CometLogger, KFoldLogger):
    """``pl_loggers.CometLogger`` with k-fold version bookkeeping."""

    def __init__(self, *args, **kwargs):
        # Build the Lightning logger exactly as the caller asked, then
        # initialise the fold counter (_fold_idx / _version) on top of it.
        super(CometLogger, self).__init__(*args, **kwargs)
        self.setup()
class CSVLogger(pl_loggers.CSVLogger, KFoldLogger):
    """``pl_loggers.CSVLogger`` with k-fold version bookkeeping."""

    def __init__(self, *args, **kwargs):
        # Build the Lightning logger exactly as the caller asked, then
        # initialise the fold counter (_fold_idx / _version) on top of it.
        super(CSVLogger, self).__init__(*args, **kwargs)
        self.setup()
class NeptuneLogger(pl_loggers.NeptuneLogger, KFoldLogger):
    """``pl_loggers.NeptuneLogger`` with k-fold version bookkeeping.

    The base must be ``pl_loggers.NeptuneLogger``. Subclassing any other
    Lightning logger would make this class accept that logger's arguments
    and send metrics to that backend instead of Neptune.
    """

    def __init__(self, *args, **kwargs):
        # Build the Lightning Neptune logger exactly as the caller asked,
        # then initialise the fold counter (_fold_idx / _version).
        super().__init__(*args, **kwargs)
        self.setup()
class TensorboardLogger(pl_loggers.TensorBoardLogger, KFoldLogger):
    """``pl_loggers.TensorBoardLogger`` with k-fold version bookkeeping."""

    def __init__(self, *args, **kwargs):
        # Build the Lightning logger exactly as the caller asked, then
        # initialise the fold counter (_fold_idx / _version) on top of it.
        super(TensorboardLogger, self).__init__(*args, **kwargs)
        self.setup()
class TestTubeLogger(pl_loggers.TestTubeLogger, KFoldLogger):
    """``pl_loggers.TestTubeLogger`` with k-fold version bookkeeping."""

    # NOTE(review): TestTubeLogger appears to have been dropped from later
    # pytorch_lightning releases. If so, this class statement raises
    # AttributeError at import time. Confirm against the pinned version.

    def __init__(self, *args, **kwargs):
        # Build the Lightning logger exactly as the caller asked, then
        # initialise the fold counter (_fold_idx / _version) on top of it.
        super(TestTubeLogger, self).__init__(*args, **kwargs)
        self.setup()
class WandbLogger(pl_loggers.WandbLogger, KFoldLogger):
    """``pl_loggers.WandbLogger`` with k-fold version bookkeeping."""

    # NOTE(review): the Lightning W&B logger may derive its version from its
    # own run id rather than ``_version``. Confirm that fold labels actually
    # take effect for this backend.

    def __init__(self, *args, **kwargs):
        # Build the Lightning logger exactly as the caller asked, then
        # initialise the fold counter (_fold_idx / _version) on top of it.
        super(WandbLogger, self).__init__(*args, **kwargs)
        self.setup()