Skip to content

Commit

Permalink
Add skip for training if MACE import fails
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliottKasoar committed May 9, 2024
1 parent bdb4cb0 commit 086dd66
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tests/test_train_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,32 @@
from pathlib import Path
import shutil

import pytest
from typer.testing import CliRunner

try:
from mace.cli.run_train import run as run_train # pylint: disable=unused-import

SKIP_TESTS = False
except ImportError:
SKIP_TESTS = True

from janus_core.cli.janus import app

DATA_PATH = Path(__file__).parent / "data"

runner = CliRunner()


@pytest.mark.skipif(SKIP_TESTS, reason="Requires updated version of MACE")
def test_help():
"""Test calling `janus train --help`."""
result = runner.invoke(app, ["train", "--help"])
assert result.exit_code == 0
assert "Usage: janus train [OPTIONS]" in result.stdout


@pytest.mark.skipif(SKIP_TESTS, reason="Requires updated version of MACE")
def test_train():
"""Test MLIP training."""
model = "test.model"
Expand Down Expand Up @@ -59,6 +69,7 @@ def test_train():
assert result.exit_code == 0


@pytest.mark.skipif(SKIP_TESTS, reason="Requires updated version of MACE")
def test_train_with_foundation():
"""Test MLIP training raises error with foundation_model in config."""
config = DATA_PATH / "mlip_train_invalid.yml"
Expand All @@ -75,6 +86,7 @@ def test_train_with_foundation():
assert isinstance(result.exception, ValueError)


@pytest.mark.skipif(SKIP_TESTS, reason="Requires updated version of MACE")
def test_fine_tune():
"""Test MLIP fine-tuning."""
model = "test-finetuned.model"
Expand Down Expand Up @@ -111,6 +123,7 @@ def test_fine_tune():
assert result.exit_code == 0


@pytest.mark.skipif(SKIP_TESTS, reason="Requires updated version of MACE")
def test_fine_tune_no_foundation():
"""Test MLIP fine-tuning raises errors without foundation_model."""
config = DATA_PATH / "mlip_fine_tune_no_foundation.yml"
Expand All @@ -123,6 +136,7 @@ def test_fine_tune_no_foundation():
assert isinstance(result.exception, ValueError)


@pytest.mark.skipif(SKIP_TESTS, reason="Requires updated version of MACE")
def test_fine_tune_invalid_foundation():
"""Test MLIP fine-tuning raises errors with invalid foundation_model."""
config = DATA_PATH / "mlip_fine_tune_invalid_foundation.yml"
Expand Down

0 comments on commit 086dd66

Please sign in to comment.