Skip to content

Commit

Permalink
Add internal progress bars using tqdm
Browse files Browse the repository at this point in the history
  • Loading branch information
Spinachboul committed Jun 19, 2024
1 parent 22662f5 commit 8f41c93
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 258 deletions.
92 changes: 43 additions & 49 deletions sktime/benchmarking/_lib_mini_kotsu/run.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Interface for running a registry of models on a registry of validations."""

import functools
import logging
import os
import time
from typing import List, Optional, Union

import pandas as pd
# from tqdm import tqdm

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -59,62 +59,56 @@ def run(
results_df = pd.DataFrame(columns=["validation_id", "model_id", "runtime_secs"])
results_df["runtime_secs"] = results_df["runtime_secs"].astype(int)

from tqdm import tqdm

results_df = results_df.set_index(["validation_id", "model_id"], drop=False)
results_list = []

validations = list(validation_registry.all())
models = list(model_registry.all())

with tqdm(total=len(validations) * len(models), desc="Running validations") as pbar:
for validation_spec in validations:
if validation_spec.deprecated:
logger.info(
f"Skipping validation: {validation_spec.id} - as is deprecated."
)
pbar.update(len(models)) # Skip all models for this validation
print("Printing Validations...")
print("\n")
for validation_spec in tqdm(validation_registry.all()):
if validation_spec.deprecated:
logger.info(
f"Skipping validation: {validation_spec.id} - as is deprecated."
)
continue
print("Running models...")
print("\n")
for model_spec in tqdm(model_registry.all()):
if model_spec.deprecated:
logger.info(f"Skipping model: {model_spec.id} - as is deprecated.")
continue
for model_spec in models:
if model_spec.deprecated:
logger.info(f"Skipping model: {model_spec.id} - as is deprecated.")
pbar.update(1)
continue

if (
not force_rerun == "all"
and not (
isinstance(force_rerun, list) and model_spec.id in force_rerun
)
and (validation_spec.id, model_spec.id) in results_df.index
):
logger.info(
f"Skipping validation - model: "
f"{validation_spec.id} - {model_spec.id}"
", as found prior result in results."
)
pbar.update(1)
continue

if (
not force_rerun == "all"
and not (isinstance(force_rerun, list) and model_spec.id in force_rerun)
and (validation_spec.id, model_spec.id) in results_df.index
):
logger.info(
f"Running validation - model:{validation_spec.id}-{model_spec.id}"
)

validation = validation_spec.make()
validation = _form_validation_partial_with_store_dirs(
validation,
artefacts_store_dir,
validation_spec,
model_spec,
f"Skipping validation - model: "
f"{validation_spec.id} - {model_spec.id}"
", as found prior result in results."
)
continue

model = model_spec.make()
results, elapsed_secs = _run_validation_model(
validation, model, run_params
)
results = _add_meta_data_to_results(
results, elapsed_secs, validation_spec, model_spec
)
results_list.append(results)
pbar.update(1)
logger.info(
f"Running validation - model: {validation_spec.id} - {model_spec.id}"
)

validation = validation_spec.make()
validation = _form_validation_partial_with_store_dirs(
validation,
artefacts_store_dir,
validation_spec,
model_spec,
)

model = model_spec.make()
results, elapsed_secs = _run_validation_model(validation, model, run_params)
results = _add_meta_data_to_results(
results, elapsed_secs, validation_spec, model_spec
)
results_list.append(results)

additional_results_df = pd.DataFrame.from_records(results_list)
results_df = pd.concat([results_df, additional_results_df], ignore_index=True)
Expand Down
209 changes: 0 additions & 209 deletions sktime/benchmarking/_lib_mini_kotsu/run1.py

This file was deleted.

0 comments on commit 8f41c93

Please sign in to comment.