Skip to content

Commit

Permalink
Update condition_prediction analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
marcosfelt committed Feb 7, 2024
1 parent ebd5991 commit 8259840
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 336 deletions.
58 changes: 43 additions & 15 deletions condition_prediction/condition_prediction/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class ConditionPrediction:
verbosity: int = 2
random_seed: int = 12345
skip_training: bool = False
save_test_predictions: bool = False
dataset_version: str = "v5"

def __post_init__(self) -> None:
Expand Down Expand Up @@ -154,6 +155,7 @@ def run_model_arguments(self) -> None:
interleave=self.interleave,
early_stopping_patience=self.early_stopping_patience,
evaluate_on_test_data=self.evaluate_on_test_data,
save_test_predictions=self.save_test_predictions,
wandb_project=self.wandb_project,
wandb_entity=self.wandb_entity,
wandb_logging=self.wandb_logging,
Expand Down Expand Up @@ -311,7 +313,7 @@ def evaluate_model(model, dataset, encoders):
).all(axis=1)
metrics["overall_accuracy_top3"] = np.mean(overall_scores_top3)

return metrics
return metrics, predictions

@staticmethod
def run_model(
Expand All @@ -326,6 +328,7 @@ def run_model(
epochs: int = 20,
early_stopping_patience: int = 5,
evaluate_on_test_data: bool = False,
save_test_predictions: bool = False,
train_mode: int = HARD_SELECTION,
batch_size: int = 512,
fp_size: int = 2048,
Expand Down Expand Up @@ -674,7 +677,7 @@ def run_model(
{
"val_last_epoch": ConditionPrediction.evaluate_model(
pred_model, val_dataset, encoders
)
)[0]
}
)

Expand All @@ -687,7 +690,7 @@ def run_model(
{
"val_best": ConditionPrediction.evaluate_model(
pred_model, val_dataset, encoders
)
)[0]
}
)

Expand Down Expand Up @@ -720,24 +723,27 @@ def run_model(
update_teacher_forcing_model_weights(
update_model=pred_model, to_copy_model=model
)
test_metrics_dict.update(
{
"test_best": ConditionPrediction.evaluate_model(
pred_model, test_dataset, encoders
)
}
metrics, predictions = ConditionPrediction.evaluate_model(
pred_model, test_dataset, encoders
)
if save_test_predictions:
predictions_transformed = np.hstack([
encoder.inverse_transform(prediction_col)
for encoder, prediction_col
in zip(encoders, predictions)
])
predictions_df = pd.DataFrame(predictions_transformed, columns=molecule_columns)
predictions_df.to_parquet(output_folder_path / "test_predictions.parquet")

test_metrics_dict.update({"test_best": metrics})
model.load_weights(last_checkpoint_filepath)
update_teacher_forcing_model_weights(
update_model=pred_model, to_copy_model=model
)
test_metrics_dict.update(
{
"test_last_epoch": ConditionPrediction.evaluate_model(
pred_model, test_dataset, encoders
)
}
metrics, predictions = ConditionPrediction.evaluate_model(
pred_model, test_dataset, encoders
)
test_metrics_dict.update({"test_last_epoch": metrics})

# Save the test metrics
test_metrics_file_path = output_folder_path / "test_metrics.json"
Expand Down Expand Up @@ -833,6 +839,20 @@ def run_model(
show_default=True,
help="If True, will generate fingerprints on the fly instead of loading them from memory",
)
@click.option(
"--skip_training",
default=False,
type=bool,
show_default=True,
help="Skip training and only do evaluation"
)
@click.option(
"--save_test_predictions",
default=False,
type=bool,
show_default=True,
help="Skip training and only do evaluation"
)
@click.option(
"--workers",
default=0,
Expand Down Expand Up @@ -988,6 +1008,8 @@ def main_click(
train_mode: int,
early_stopping_patience: int,
evaluate_on_test_data: bool,
skip_training: bool,
save_test_predictions: bool,
generate_fingerprints: bool,
workers: int,
fp_size: int,
Expand Down Expand Up @@ -1044,6 +1066,8 @@ def main_click(
train_mode=train_mode,
early_stopping_patience=early_stopping_patience,
evaluate_on_test_data=evaluate_on_test_data,
skip_training=skip_training,
save_test_predictions=save_test_predictions,
generate_fingerprints=generate_fingerprints,
workers=workers,
fp_size=fp_size,
Expand Down Expand Up @@ -1084,6 +1108,8 @@ def main(
train_mode: int,
early_stopping_patience: int,
evaluate_on_test_data: bool,
skip_training: bool,
save_test_predictions: bool,
generate_fingerprints: bool,
workers: int,
fp_size: int,
Expand Down Expand Up @@ -1207,6 +1233,8 @@ def main(
train_mode=train_mode,
early_stopping_patience=early_stopping_patience,
evaluate_on_test_data=evaluate_on_test_data,
skip_training=skip_training,
save_test_predictions=save_test_predictions,
wandb_entity=wandb_entity,
wandb_project=wandb_project,
wandb_logging=wandb_logging,
Expand Down
360 changes: 39 additions & 321 deletions notebooks/plot_model_performance_wandb.ipynb

Large diffs are not rendered by default.

0 comments on commit 8259840

Please sign in to comment.