Skip to content

Commit

Permalink
Log skorch version in NeptuneLogger (#975)
Browse files Browse the repository at this point in the history
  • Loading branch information
AleksanderWWW committed Jun 2, 2023
1 parent f69be0f commit 5b222a5
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
13 changes: 10 additions & 3 deletions skorch/callbacks/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,11 @@ def __init__(
self.keys_ignored = keys_ignored
self.base_namespace = base_namespace

def _log_integration_version(self) -> None:
from skorch import __version__

self.run['source_code/integrations/skorch'] = __version__

@property
def _metric_logger(self):
return self.run[self._base_namespace]
Expand All @@ -196,6 +201,8 @@ def initialize(self):
else:
self._base_namespace = self.base_namespace

self._log_integration_version()

return self

def on_train_begin(self, net, X, y, **kwargs):
Expand Down Expand Up @@ -273,11 +280,11 @@ def _log_metric(self, name, logs, batch):
@staticmethod
def _model_summary_file(model):
try:
# neptune-client=0.9.0+ package structure
from neptune.new.types import File
except ImportError:
# neptune-client>=1.0.0 package structure
from neptune.types import File
except ImportError:
# neptune-client=0.9.0+ package structure
from neptune.new.types import File

return File.from_content(str(model), extension='txt')

Expand Down
3 changes: 3 additions & 0 deletions skorch/tests/callbacks/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ def net_fitted(
def test_experiment_closed_automatically(self, net_fitted, mock_experiment):
assert mock_experiment.stop.call_count == 1

def test_version_logged(self, net_fitted, mock_experiment):
assert mock_experiment.exists("source_code/integrations/skorch")

def test_experiment_log_call_counts(self, net_fitted, mock_experiment):
# (3 x dur + 3 x train_loss + 3 x valid_loss + 3 x valid_acc = 12) + base metrics
assert mock_experiment.__getitem__.call_count == 12 + self.NUM_BASE_METRICS
Expand Down

0 comments on commit 5b222a5

Please sign in to comment.