Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@ Changelog
v1.4.3 (**.**.2025)
===================

Feature
-------

- Added stagger_logging functionality to logtables.
- Allow condition for experimenter.get_table() and experimenter.get_logtable()

Fix
---

- Fix bug where the logtable_name was not overwritten by `table_name` updates in the `PyExperimenter` class.
- Allow condition for experimenter.get_table() and experimenter.get_logtable()

v1.4.2 (12.06.2024)
===================
Expand Down
20 changes: 17 additions & 3 deletions py_experimenter/experimenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def __init__(
use_ssh_tunnel: Optional[bool] = None,
table_name: Optional[str] = None,
database_name: Optional[str] = None,
stagger_logging: bool = False,
log_every_n_seconds: int = None,
use_codecarbon: bool = True,
name="PyExperimenter",
logger_name: str = "py-experimenter",
Expand Down Expand Up @@ -60,6 +62,10 @@ def __init__(
`experiment_configuration_file_path`. If None, the database name is taken from the experiment configuration
file. Defaults to None.
:type database_name: str, optional
:param stagger_logging: If True, the logs are written to the database every `log_every_n_seconds` seconds. Defaults to False.
:type stagger_logging: bool, optional
:param log_every_n_seconds: The time interval in seconds at which the logs are written to the database. Defaults to 10.
:type log_every_n_seconds: int, optional
:param use_codecarbon: If True, the carbon emissions are tracked and stored in the database. Defaults to True.
:type use_codecarbon: bool, optional
:param name: The name of the PyExperimenter, which will be logged in the according column in the database table.
Expand Down Expand Up @@ -96,6 +102,13 @@ def __init__(
handler.setFormatter(formatter)
self.logger.addHandler(handler)

self.stagger_logging = stagger_logging
self.log_every_n_seconds = log_every_n_seconds
if self.stagger_logging and (log_every_n_seconds is None or not isinstance(log_every_n_seconds, int)):
raise ValueError("log_every_n_seconds must be set to an integer when stagger_logging is True")
if self.stagger_logging and log_every_n_seconds <= 0:
raise ValueError("log_every_n_seconds must be greater than 0 when stagger_logging is True and log_every_n_seconds is set")

self.config = PyExperimenterCfg.extract_config(experiment_configuration_file_path, logger=self.logger, overwritten_table_name=table_name)

self.use_codecarbon = use_codecarbon
Expand Down Expand Up @@ -378,7 +391,7 @@ def _worker(self, experiment_function: Callable[[Dict, Dict, ResultProcessor], N
break

def _execution_wrapper(
self, experiment_function: Callable[[Dict, Dict, ResultProcessor], Optional[ExperimentStatus]], random_order: bool
self, experiment_function: Callable[[Dict, ResultProcessor, Dict], Optional[ExperimentStatus]], random_order: bool
) -> None:
"""
Executes the given `experiment_function` on one open experiment. To that end, one of the open experiments is pulled
Expand All @@ -397,7 +410,7 @@ def _execution_wrapper(
and do not appear in the table. Additionally errors due to returning `ExperimentStatus.ERROR` are not logged.

:param experiment_function: The function that should be executed with the different parametrizations.
:type experiment_function: Callable[[dict, dict, ResultProcessor], None]
:type experiment_function: Callable[[dict, ResultProcessor, dict], None]
:param random_order: If True, the order of the experiments is determined randomly. Defaults to False.
:type random_order: bool
:raises NoExperimentsLeftError: If there are no experiments left to be executed.
Expand All @@ -407,7 +420,7 @@ def _execution_wrapper(
self._execute_experiment(experiment_id, keyfield_values, experiment_function)

def _execute_experiment(self, experiment_id, keyfield_values, experiment_function):
result_processor = ResultProcessor(self.config.database_configuration, self.db_connector, experiment_id=experiment_id, logger=self.logger)
result_processor = ResultProcessor(self.config.database_configuration, self.db_connector, experiment_id=experiment_id, logger=self.logger, stagger_logging=self.stagger_logging, log_every_n_seconds=self.log_every_n_seconds)
result_processor._set_name(self.name)
result_processor._set_machine(socket.gethostname())

Expand Down Expand Up @@ -449,6 +462,7 @@ def _execute_experiment(self, experiment_id, keyfield_values, experiment_functio
tracker.stop()
emission_data = tracker._prepare_emissions_data().values
result_processor._write_emissions(emission_data, self.codecarbon_offline_mode)
result_processor.write_logs(force_write=True)

def _write_codecarbon_config(self) -> None:
""" "
Expand Down
49 changes: 42 additions & 7 deletions py_experimenter/result_processor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import time
from configparser import ConfigParser
from copy import deepcopy
from typing import Dict, List, Tuple
Expand All @@ -10,7 +11,11 @@
from py_experimenter.database_connector import DatabaseConnector
from py_experimenter.database_connector_lite import DatabaseConnectorLITE
from py_experimenter.database_connector_mysql import DatabaseConnectorMYSQL
from py_experimenter.exceptions import InvalidConfigError, InvalidLogFieldError, InvalidResultFieldError
from py_experimenter.exceptions import (
InvalidConfigError,
InvalidLogFieldError,
InvalidResultFieldError,
)


class ResultProcessor:
Expand All @@ -19,12 +24,23 @@ class ResultProcessor:
database.
"""

def __init__(self, database_config: DatabaseCfg, db_connector: DatabaseConnector, experiment_id: int, logger):
def __init__(self, database_config: DatabaseCfg, db_connector: DatabaseConnector, experiment_id: int, logger, stagger_logging: bool = False, log_every_n_seconds: int = None):
    """
    Initializes the ResultProcessor for a single experiment.

    :param database_config: Database configuration used to resolve table and logtable names.
    :type database_config: DatabaseCfg
    :param db_connector: Connector used to issue write queries against the database.
    :type db_connector: DatabaseConnector
    :param experiment_id: ID of the experiment whose results and logs are processed.
    :type experiment_id: int
    :param logger: Logger instance used for diagnostic output.
    :param stagger_logging: If True, log writes are batched and flushed at most every
        `log_every_n_seconds` seconds instead of on every `process_logs` call. Defaults to False.
    :type stagger_logging: bool, optional
    :param log_every_n_seconds: Flush interval in seconds; required (positive integer) when
        `stagger_logging` is True.
    :type log_every_n_seconds: int, optional
    :raises ValueError: If `stagger_logging` is True and `log_every_n_seconds` is not a positive integer.
    """
    self.logger = logger
    self.database_config = database_config
    self.db_connector = db_connector
    self.experiment_id = experiment_id
    self.experiment_id_condition = f"ID = {self.experiment_id}"
    self.stagger_logging = stagger_logging
    self.log_every_n_seconds = log_every_n_seconds
    # Start the stagger clock at construction time so the first flush
    # happens at most log_every_n_seconds after the experiment starts.
    self.last_log_time = time.time()
    self.log_queries = []
    # NOTE(review): log_counter is never read in the visible code — possibly dead state.
    self.log_counter = 0

    if self.stagger_logging:
        # None fails the isinstance check as well, so a separate None test is redundant.
        if not isinstance(self.log_every_n_seconds, int):
            raise ValueError("log_every_n_seconds must be set to an integer when stagger_logging is True")
        if self.log_every_n_seconds <= 0:
            raise ValueError("log_every_n_seconds must be greater than 0 when stagger_logging is True")

def process_results(self, results: Dict) -> None:
"""
Expand Down Expand Up @@ -66,27 +82,46 @@ def _add_timestamps_to_results(results: Dict) -> List[Tuple[str, object]]:
result_fields_with_timestep[f"{result_field}_timestamp"] = time
return result_fields_with_timestep

def process_logs(self, logs: Dict[str, Dict[str, str]]) -> None:
def process_logs(self, logs: Dict[str, Dict[str, str]], force_write: bool = False) -> None:
    """
    Appends logs to the logtables. Raises InvalidLogFieldError if the given logs are invalid.

    The logs are structured as follows: dictionary keys are the logtable names (without the
    `table_name__` prefix); each key maps to an inner dictionary with column names as keys and
    results as values. The prepared write queries are first queued and then written to the
    database jointly: `write_logs` is invoked at the end of every call and decides — based on
    `stagger_logging`, the elapsed time since the last flush, and `force_write` — whether the
    queued queries are actually executed now.

    :param logs: Logs to be appended to the logtables.
    :type logs: Dict[str, Dict[str, str]]
    :param force_write: If True, the queued logs are written to the database immediately.
    :type force_write: bool, optional
    :raises InvalidLogFieldError: If the given logs do not match the configured logtable fields.
    """
    if not self._valid_logtable_logs(logs):
        raise InvalidLogFieldError("Invalid logtable entries. See logs for more information")

    # Named `timestamp` (not `time`) to avoid shadowing the module-level `time` import.
    timestamp = utils.get_timestamp_representation()
    for logtable_identifier, log_entries in logs.items():
        logtable_name = f"{self.database_config.table_name}__{logtable_identifier}"
        log_entries["experiment_id"] = str(self.experiment_id)
        log_entries["timestamp"] = f"{timestamp}"
        stmt = self.db_connector.prepare_write_query(logtable_name, log_entries.keys())
        # Snapshot the values with list(): log_entries is caller-owned, and a live
        # dict view could change before the staggered flush actually executes.
        self.log_queries.append((stmt, list(log_entries.values())))

    self.write_logs(force_write)

def write_logs(self, force_write: bool = False) -> None:
    """
    Flushes the queued log queries to the database if the write conditions are met.

    The queue is flushed when `force_write` is True, when stagger logging is disabled,
    or when at least `log_every_n_seconds` seconds have passed since the last flush.
    After a successful flush the queue is emptied and the stagger clock is reset.
    No database call is made when the queue is empty.

    :param force_write: If True, the queued logs are written regardless of the stagger timer.
    :type force_write: bool, optional
    """
    if force_write or not self.stagger_logging:
        should_write = True
    else:
        # Only time-gate writes when staggering is active.
        should_write = (time.time() - self.last_log_time) >= self.log_every_n_seconds

    if should_write and self.log_queries:
        self.db_connector.execute_queries(self.log_queries)
        self.log_queries = []
        self.last_log_time = time.time()


def _valid_logtable_logs(self, logs: Dict[str, Dict[str, str]]) -> bool:
logs = {f"{self.database_config.table_name}__{logtable_name}": logtable_entries for logtable_name, logtable_entries in logs.items()}
Expand Down
30 changes: 29 additions & 1 deletion test/test_logtables/test_mysql.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import logging
import os
from math import cos, sin
import os
import time
from unittest.mock import MagicMock, patch

from freezegun import freeze_time
from mock import MagicMock, call, patch
from omegaconf import OmegaConf

from py_experimenter import database_connector_mysql
from py_experimenter.config import DatabaseCfg
from py_experimenter.database_connector import DatabaseConnector
from py_experimenter.database_connector_mysql import DatabaseConnectorMYSQL
Expand Down Expand Up @@ -177,3 +180,28 @@ def test_integration_without_resultfields():
assert logtable2 == [(1, 1, 1), (2, 1, 3)]
assert timesteps == timesteps2
experimenter.close_ssh()

def own_function_without_resultfields_stagger(keyfields: dict, result_processor: ResultProcessor, custom_fields: dict):
    # Two quick log calls should land in the same staggered batch ...
    for value in (0, 2):
        result_processor.process_logs({"log": {"test": value}})
    # ... then wait out the stagger interval so the last entry opens a new batch.
    time.sleep(10)
    result_processor.process_logs({"log": {"test": 4}})


def test_stagger_logging():
    # Run a single staggered-logging experiment and verify the log rows were
    # flushed in two batches (two distinct timestamps: first 2 rows, then 1).
    config_path = os.path.join("test", "test_logtables", "mysql_logtables.yml")
    experimenter = PyExperimenter(config_path, use_ssh_tunnel=False, stagger_logging=True, log_every_n_seconds=10)

    try:
        experimenter.delete_table()
    except Exception:
        # The table may not exist yet; either way we start from a clean slate.
        pass

    experimenter.fill_table_from_config()
    experimenter.execute(own_function_without_resultfields_stagger, max_experiments=1)

    log_rows = experimenter.get_logtable("log")
    batch_sizes = log_rows.groupby("timestamp").count()["test"]
    assert (batch_sizes == [2, 1]).all()
28 changes: 27 additions & 1 deletion test/test_logtables/test_sqlite.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
from math import cos, sin
import os
import time

from freezegun import freeze_time
from mock import MagicMock, call, patch
Expand Down Expand Up @@ -155,3 +156,28 @@ def test_integration_without_resultfields():
non_timesteps_2 = [x[:2] + x[3:] for x in logtable2]
assert non_timesteps_2 == [(1, 1, 1), (2, 1, 3)]
assert timesteps == timesteps_2

def own_function_without_resultfields_stagger(keyfields: dict, result_processor: ResultProcessor, custom_fields: dict):
    # Two quick log calls should land in the same staggered batch ...
    for value in (0, 2):
        result_processor.process_logs({"log": {"test": value}})
    # ... then wait out the stagger interval so the last entry opens a new batch.
    time.sleep(10)
    result_processor.process_logs({"log": {"test": 4}})


def test_stagger_logging():
    # Run a single staggered-logging experiment and verify the log rows were
    # flushed in two batches (two distinct timestamps: first 2 rows, then 1).
    config_path = os.path.join("test", "test_logtables", "sqlite_logtables.yml")
    experimenter = PyExperimenter(config_path, use_ssh_tunnel=False, stagger_logging=True, log_every_n_seconds=10)

    try:
        experimenter.delete_table()
    except Exception:
        # The table may not exist yet; either way we start from a clean slate.
        pass

    experimenter.fill_table_from_config()
    experimenter.execute(own_function_without_resultfields_stagger, max_experiments=1)

    log_rows = experimenter.get_logtable("log")
    batch_sizes = log_rows.groupby("timestamp").count()["test"]
    assert (batch_sizes == [2, 1]).all()
Loading