Skip to content

Commit

Permalink
feat: add do_truncate control for the load function (#656)
Browse files Browse the repository at this point in the history
feat: add do_truncate control for the load function (#656)
  • Loading branch information
RolandMinrui authored Mar 4, 2025
1 parent 7355174 commit 2b960a5
Showing 2 changed files with 14 additions and 6 deletions.
10 changes: 7 additions & 3 deletions rdagent/app/data_science/loop.py
Original file line number Diff line number Diff line change
@@ -142,13 +142,15 @@ def record(self, prev_out: dict[str, Any]):
logger.log_object(self.trace.sota_experiment(), tag="SOTA experiment")


def main(path=None, output_path=None, step_n=None, loop_n=None, competition="bms-molecular-translation"):
def main(
path=None, output_path=None, step_n=None, loop_n=None, competition="bms-molecular-translation", do_truncate=True
):
"""
Parameters
----------
path :
path like `$LOG_PATH/__session__/1/0_propose`. It indicates that we restore the state that after finish the step 0 in loop1
path like `$LOG_PATH/__session__/1/0_propose`. It indicates that we restore the state as it was after finishing step 0 in loop 1
output_path :
path like `$LOG_PATH`. It indicates where we want to save our session and log information.
step_n :
@@ -158,6 +160,8 @@ def main(path=None, output_path=None, step_n=None, loop_n=None, competition="bms
- if current loop is incomplete, it will be counted as the first loop for completion.
- if both step_n and loop_n are provided, the process will stop as soon as either condition is met.
competition :
do_truncate :
If set to True, the logger will truncate future log messages (relative to the end time of the last recorded step in the restored session) by calling `logger.storage.truncate`.
Auto R&D Evolving loop for models in a Kaggle scenario.
@@ -181,7 +185,7 @@ def main(path=None, output_path=None, step_n=None, loop_n=None, competition="bms
if path is None:
kaggle_loop = DataScienceRDLoop(DS_RD_SETTING)
else:
kaggle_loop = DataScienceRDLoop.load(path, output_path)
kaggle_loop = DataScienceRDLoop.load(path, output_path, do_truncate)
kaggle_loop.run(step_n=step_n, loop_n=loop_n)


10 changes: 7 additions & 3 deletions rdagent/utils/workflow.py
Original file line number Diff line number Diff line change
@@ -161,7 +161,9 @@ def dump(self, path: str | Path) -> None:
pickle.dump(self, f)

@classmethod
def load(cls, path: Union[str, Path], output_path: Optional[Union[str, Path]] = None) -> "LoopBase":
def load(
cls, path: Union[str, Path], output_path: Optional[Union[str, Path]] = None, do_truncate: bool = True
) -> "LoopBase":
path = Path(path)
with path.open("rb") as f:
session = cast(LoopBase, pickle.load(f))
@@ -175,8 +177,10 @@ def load(cls, path: Union[str, Path], output_path: Optional[Union[str, Path]] =
# set trace path
logger.set_trace_path(session.session_folder.parent)

max_loop = max(session.loop_trace.keys())
logger.storage.truncate(time=session.loop_trace[max_loop][-1].end)
# truncate future message
if do_truncate:
max_loop = max(session.loop_trace.keys())
logger.storage.truncate(time=session.loop_trace[max_loop][-1].end)
return session


0 comments on commit 2b960a5

Please sign in to comment.