Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add training for MACE #127

Merged
merged 4 commits into from
May 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions docs/source/apidoc/janus_core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ janus\_core.cli.md module
:undoc-members:
:show-inheritance:

janus\_core.cli.train module
----------------------------

.. automodule:: janus_core.cli.train
:members:
:special-members:
:private-members:
:undoc-members:
:show-inheritance:

janus\_core.cli.types module
----------------------------

Expand Down Expand Up @@ -104,6 +114,16 @@ janus\_core.helpers.mlip\_calculators module
:undoc-members:
:show-inheritance:

janus\_core.helpers.train module
--------------------------------

.. automodule:: janus_core.helpers.train
:members:
:special-members:
:private-members:
:undoc-members:
:show-inheritance:

janus\_core.calculations.single\_point module
---------------------------------------------

Expand Down
3 changes: 3 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
}
numpydoc_class_members_toctree = False

# Mock import of MACE module to avoid breaking build
autodoc_mock_imports = ["mace.cli.run_train"]

intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"numpy": ("https://numpy.org/doc/stable/", None),
Expand Down
7 changes: 7 additions & 0 deletions janus_core/cli/janus.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@
app.command()(singlepoint)
app.command()(geomopt)
app.command()(md)
# Train not imlpemented in older versions of MACE
try:
from janus_core.cli.train import train

app.command()(train)
except NotImplementedError:
pass


@app.callback(invoke_without_command=True, help="")
Expand Down
25 changes: 25 additions & 0 deletions janus_core/cli/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Set up MLIP training commandline interface."""

from pathlib import Path
from typing import Annotated

from typer import Option, Typer

from janus_core.helpers.train import train as run_train

app = Typer()


@app.command(help="Perform single point calculations and save to file.")
def train(
mlip_config: Annotated[Path, Option(help="Configuration file to pass to MLIP CLI.")]
):
"""
Run training for MLIP by passing a configuration file to the MLIP's CLI.

Parameters
----------
mlip_config : Path
Configuration file to pass to MLIP CLI.
"""
run_train(mlip_config)
69 changes: 69 additions & 0 deletions janus_core/helpers/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""Train MLIP."""

from pathlib import Path
from typing import Optional

try:
from mace.cli.run_train import run as run_train
except ImportError as e:
raise NotImplementedError("Please update MACE to use this module.") from e
from mace.tools import build_default_arg_parser as mace_parser
import yaml

from janus_core.helpers.janus_types import PathLike


def check_files_exist(config: dict, req_file_keys: list[PathLike]) -> None:
"""
Check files specified in the MLIP configuration file exist.

Parameters
----------
config : dict
MLIP configuration file options.
req_file_keys : list[Pathlike]
List of files that must exist if defined in the configuration file.

Raises
------
FileNotFoundError
If a key from `req_file_keys` is in the configuration file, but the
file corresponding to the configuration value do not exist.
"""
for file_key in req_file_keys:
# Only check if file key is in the configuration file
if file_key in config and not Path(config[file_key]).exists():
raise FileNotFoundError(f"{config[file_key]} does not exist")


def train(
mlip_config: PathLike, req_file_keys: Optional[list[PathLike]] = None
) -> None:
"""
Run training for MLIP by passing a configuration file to the MLIP's CLI.

Currently only supports MACE models, but this can be extended by replacing the
argument parsing.

Parameters
----------
mlip_config : PathLike
Configuration file to pass to MLIP.
req_file_keys : Optional[list[PathLike]]
List of files that must exist if defined in the configuration file.
Default is ["train_file", "test_file", "valid_file", "statistics_file"].
"""
if req_file_keys is None:
req_file_keys = ["train_file", "test_file", "valid_file", "statistics_file"]

# Validate inputs
with open(mlip_config, encoding="utf8") as file:
options = yaml.safe_load(file)
check_files_exist(options, req_file_keys)

if "foundation_model" in options:
alinelena marked this conversation as resolved.
Show resolved Hide resolved
print(f"Fine tuning model: {options['foundation_model']}")

# Path must be passed as a string
mlip_args = mace_parser().parse_args(["--config", str(mlip_config)])
run_train(mlip_args)