Skip to content

Commit

Permalink
Add training for MACE (#127)
Browse files Browse the repository at this point in the history
* Add initial training/fine-tuning. Currently supports only MACE, and tests are not yet included, pending a MACE release.

---------

Co-authored-by: ElliottKasoar <ElliottKasoar@users.noreply.github.com>
  • Loading branch information
ElliottKasoar and ElliottKasoar committed May 4, 2024
1 parent 41bf1f4 commit 36aeb49
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 0 deletions.
20 changes: 20 additions & 0 deletions docs/source/apidoc/janus_core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ janus\_core.cli.md module
:undoc-members:
:show-inheritance:

janus\_core.cli.train module
----------------------------

.. automodule:: janus_core.cli.train
:members:
:special-members:
:private-members:
:undoc-members:
:show-inheritance:

janus\_core.cli.types module
----------------------------

Expand Down Expand Up @@ -104,6 +114,16 @@ janus\_core.helpers.mlip\_calculators module
:undoc-members:
:show-inheritance:

janus\_core.helpers.train module
--------------------------------

.. automodule:: janus_core.helpers.train
:members:
:special-members:
:private-members:
:undoc-members:
:show-inheritance:

janus\_core.calculations.single\_point module
---------------------------------------------

Expand Down
3 changes: 3 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
}
numpydoc_class_members_toctree = False

# Mock import of the MACE training module to avoid breaking the docs build
# (janus_core.helpers.train raises NotImplementedError if this import fails)
autodoc_mock_imports = ["mace.cli.run_train"]

intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"numpy": ("https://numpy.org/doc/stable/", None),
Expand Down
7 changes: 7 additions & 0 deletions janus_core/cli/janus.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@
app.command()(singlepoint)
app.command()(geomopt)
app.command()(md)
# Train is not implemented in older versions of MACE: importing
# janus_core.cli.train raises NotImplementedError in that case, so the
# command is only registered when the import succeeds
try:
from janus_core.cli.train import train

app.command()(train)
except NotImplementedError:
pass


@app.callback(invoke_without_command=True, help="")
Expand Down
25 changes: 25 additions & 0 deletions janus_core/cli/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Set up MLIP training commandline interface."""

from pathlib import Path
from typing import Annotated

from typer import Option, Typer

from janus_core.helpers.train import train as run_train

app = Typer()


@app.command(help="Run training for an MLIP using a configuration file.")
def train(
    mlip_config: Annotated[Path, Option(help="Configuration file to pass to MLIP CLI.")]
):
    """
    Run training for MLIP by passing a configuration file to the MLIP's CLI.

    Parameters
    ----------
    mlip_config : Path
        Configuration file to pass to MLIP CLI.
    """
    # All validation and the call into the MLIP's own CLI live in the helper
    run_train(mlip_config)
69 changes: 69 additions & 0 deletions janus_core/helpers/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""Train MLIP."""

from pathlib import Path
from typing import Optional

try:
from mace.cli.run_train import run as run_train
except ImportError as e:
raise NotImplementedError("Please update MACE to use this module.") from e
from mace.tools import build_default_arg_parser as mace_parser
import yaml

from janus_core.helpers.janus_types import PathLike


def check_files_exist(config: dict, req_file_keys: list[str]) -> None:
    """
    Check files specified in the MLIP configuration file exist.

    Parameters
    ----------
    config : dict
        MLIP configuration file options.
    req_file_keys : list[str]
        Configuration keys whose values are files that must exist if defined in
        the configuration file.

    Raises
    ------
    FileNotFoundError
        If a key from `req_file_keys` is in the configuration file, but the
        file corresponding to the configuration value does not exist.
    """
    for file_key in req_file_keys:
        # Only check keys that are defined: a key that is absent, or present
        # with a null value, has no file to look for
        file_name = config.get(file_key)
        if file_name is None:
            continue
        if not Path(file_name).exists():
            raise FileNotFoundError(f"{file_name} does not exist")


def train(
    mlip_config: PathLike, req_file_keys: Optional[list[str]] = None
) -> None:
    """
    Run training for MLIP by passing a configuration file to the MLIP's CLI.

    Currently only supports MACE models, but this can be extended by replacing the
    argument parsing.

    Parameters
    ----------
    mlip_config : PathLike
        Configuration file to pass to MLIP.
    req_file_keys : Optional[list[str]]
        Configuration keys whose values are files that must exist if defined in
        the configuration file.
        Default is ["train_file", "test_file", "valid_file", "statistics_file"].

    Raises
    ------
    ValueError
        If the configuration file does not contain a mapping of options.
    FileNotFoundError
        If a file referenced by one of `req_file_keys` does not exist.
    """
    if req_file_keys is None:
        req_file_keys = ["train_file", "test_file", "valid_file", "statistics_file"]

    # Validate inputs
    with open(mlip_config, encoding="utf8") as file:
        options = yaml.safe_load(file)

    # An empty file parses to None (and other YAML documents may parse to a
    # list or scalar); fail early with a clear message rather than an opaque
    # TypeError from the membership tests below
    if not isinstance(options, dict):
        raise ValueError(f"{mlip_config} must contain a mapping of MLIP options")

    check_files_exist(options, req_file_keys)

    if "foundation_model" in options:
        print(f"Fine tuning model: {options['foundation_model']}")

    # Path must be passed as a string
    mlip_args = mace_parser().parse_args(["--config", str(mlip_config)])
    run_train(mlip_args)

0 comments on commit 36aeb49

Please sign in to comment.