diff --git a/src/pytorch_tabular/tabular_model_sweep.py b/src/pytorch_tabular/tabular_model_sweep.py
index a2c18932..bf3a192f 100644
--- a/src/pytorch_tabular/tabular_model_sweep.py
+++ b/src/pytorch_tabular/tabular_model_sweep.py
@@ -359,27 +359,31 @@ def _init_tabular_model(m):
             res_dict["time_taken"] = time.time() - start_time
             res_dict["time_taken_per_epoch"] = res_dict["time_taken"] / res_dict["epochs"]
 
-            if verbose:
-                logger.info(f"Finished Training {name}")
-                logger.info("Results:" f" {', '.join([f'{k}: {v}' for k,v in res_dict.items()])}")
-            res_dict["params"] = params
-            results.append(res_dict)
-            if best_model is None:
-                best_model = tabular_model
-            else:
-                if is_lower_better:
-                    if res_dict[f"test_{rank_metric[0]}"] < best_score:
-                        best_model = tabular_model
-                        best_score = res_dict[f"test_{rank_metric[0]}"]
-                else:
-                    if res_dict[f"test_{rank_metric[0]}"] > best_score:
-                        best_model = tabular_model
+            if verbose:
+                logger.info(f"Finished Training {name}")
+                logger.info("Results:" f" {', '.join([f'{k}: {v}' for k,v in res_dict.items()])}")
+            res_dict["params"] = params
+            results.append(res_dict)
+            if return_best_model:
+                tabular_model.datamodule = None
+                if best_model is None:
+                    best_model = copy.deepcopy(tabular_model)
                     best_score = res_dict[f"test_{rank_metric[0]}"]
+                else:
+                    if is_lower_better:
+                        if res_dict[f"test_{rank_metric[0]}"] < best_score:
+                            best_model = copy.deepcopy(tabular_model)
+                            best_score = res_dict[f"test_{rank_metric[0]}"]
+                    else:
+                        if res_dict[f"test_{rank_metric[0]}"] > best_score:
+                            best_model = copy.deepcopy(tabular_model)
+                            best_score = res_dict[f"test_{rank_metric[0]}"]
     if verbose:
         logger.info("Model Sweep Finished")
         logger.info(f"Best Model: {best_model.name}")
     results = pd.DataFrame(results).sort_values(by=f"test_{rank_metric[0]}", ascending=is_lower_better)
-    if return_best_model:
+    if return_best_model and best_model is not None:
+        best_model.datamodule = datamodule
         return results, best_model
     else:
         return results
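Note on the change: the sweep loop now only tracks a best model when `return_best_model` is set, detaches the heavy `datamodule` before taking a `copy.deepcopy` of a winning candidate (so the snapshot neither duplicates the dataset nor gets mutated by later iterations), and reattaches the shared `datamodule` to the final winner before returning. A minimal sketch of that pattern, assuming duck-typed model objects with a writable `datamodule` attribute; the `pick_best` helper is hypothetical and not part of the repository:

```python
import copy

def pick_best(candidates, scores, datamodule, lower_is_better=True):
    """Keep a deep copy of the best-scoring model, mirroring the sweep loop."""
    best_model, best_score = None, None
    for model, score in zip(candidates, scores):
        model.datamodule = None  # drop the heavy reference before copying
        better = (
            best_score is None
            or (score < best_score if lower_is_better else score > best_score)
        )
        if better:
            best_model = copy.deepcopy(model)  # snapshot, isolated from later runs
            best_score = score
    if best_model is not None:
        best_model.datamodule = datamodule  # reattach once, for the winner only
    return best_model, best_score
```

The added `best_model is not None` guard in the return path mirrors this: if no candidate ever produced a score, the function falls through to returning only the results frame instead of dereferencing a `None` model.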