diff --git a/src/dvclive/fastai.py b/src/dvclive/fastai.py index 8d7f9e6e..f62299e2 100644 --- a/src/dvclive/fastai.py +++ b/src/dvclive/fastai.py @@ -30,7 +30,7 @@ def __init__( model_file: Optional[str] = None, with_opt: bool = False, live: Optional[Live] = None, - **kwargs + **kwargs, ): super().__init__() self.model_file = model_file @@ -38,7 +38,22 @@ def __init__( self.live = live if live is not None else Live(**kwargs) self.freeze_stage_ended = False + def before_fit(self): + if hasattr(self, "lr_finder") or hasattr(self, "gather_preds"): + return + params = { + "model": type(self.learn.model).__qualname__, + "batch_size": getattr(self.dls, "bs", None), + "batch_per_epoch": len(getattr(self.dls, "train", [])), + "frozen": bool(getattr(self.opt, "frozen_idx", -1)), + "frozen_idx": getattr(self.opt, "frozen_idx", -1), + "transforms": f"{getattr(self.dls, 'tfms', None)}", + } + self.live.log_params(params) + def after_epoch(self): + if hasattr(self, "lr_finder") or hasattr(self, "gather_preds"): + return logged_metrics = False for key, value in zip( self.learn.recorder.metric_names, self.learn.recorder.log diff --git a/tests/test_frameworks/test_fastai.py b/tests/test_frameworks/test_fastai.py index 86c975ec..e0035013 100644 --- a/tests/test_frameworks/test_fastai.py +++ b/tests/test_frameworks/test_fastai.py @@ -48,7 +48,12 @@ def test_fastai_callback(tmp_dir, data_loader, mocker): learn.fit_one_cycle(2, cbs=[callback]) spy.assert_called_once() - assert os.path.exists(live.dir) + assert (tmp_dir / live.dir).exists() + assert (tmp_dir / live.params_file).exists() + assert (tmp_dir / live.params_file).read_text() == ( + "model: TabularModel\nbatch_size: 2\nbatch_per_epoch: 2\nfrozen: false" + "\nfrozen_idx: 0\ntransforms: None\n" + ) metrics_path = tmp_dir / live.plots_dir / Metric.subfolder train_path = metrics_path / "train"