Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
dswigh committed Jun 24, 2024
1 parent 3b368ff commit 002032a
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions condition_prediction/condition_prediction/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,14 +727,19 @@ def run_model(
pred_model, test_dataset, encoders
)
if save_test_predictions:
predictions_transformed = np.hstack([
encoder.inverse_transform(prediction_col)
for encoder, prediction_col
in zip(encoders, predictions)
])
predictions_df = pd.DataFrame(predictions_transformed, columns=molecule_columns)
predictions_df.to_parquet(output_folder_path / "test_predictions.parquet")

predictions_transformed = np.hstack(
[
encoder.inverse_transform(prediction_col)
for encoder, prediction_col in zip(encoders, predictions)
]
)
predictions_df = pd.DataFrame(
predictions_transformed, columns=molecule_columns
)
predictions_df.to_parquet(
output_folder_path / "test_predictions.parquet"
)

test_metrics_dict.update({"test_best": metrics})
model.load_weights(last_checkpoint_filepath)
update_teacher_forcing_model_weights(
Expand Down Expand Up @@ -844,14 +849,14 @@ def run_model(
default=False,
type=bool,
show_default=True,
help="Skip training and only do evaluation"
help="Skip training and only do evaluation",
)
@click.option(
"--save_test_predictions",
default=False,
type=bool,
show_default=True,
help="Skip training and only do evaluation"
help="Skip training and only do evaluation",
)
@click.option(
"--workers",
Expand Down Expand Up @@ -1108,7 +1113,7 @@ def main(
train_mode: int,
early_stopping_patience: int,
evaluate_on_test_data: bool,
skip_training: bool,
skip_training: bool,
save_test_predictions: bool,
generate_fingerprints: bool,
workers: int,
Expand Down

0 comments on commit 002032a

Please sign in to comment.