diff --git a/timm/utils/checkpoint_saver.py b/timm/utils/checkpoint_saver.py index 37765ee1d9..3f0aa535e1 100644 --- a/timm/utils/checkpoint_saver.py +++ b/timm/utils/checkpoint_saver.py @@ -107,7 +107,12 @@ def save_checkpoint(self, epoch, metric=None): "model_kwargs": self.args.model_kwargs, } torch.save(model_dict, temp_location) + torch.save( + get_state_dict(self.model, self.unwrap_fn), + os.path.join(temp_dir, "state_dict.pt"), + ) mlflow.log_artifact(temp_location) + mlflow.log_artifact(os.path.join(temp_dir, "state_dict.pt")) return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)