121 add training (#123)
* add train for MLIPs, MACE only

---------

Co-authored-by: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com>
Co-authored-by: Jacob Wilkins <46597752+oerc0122@users.noreply.github.com>
3 people authored May 23, 2024
1 parent 18695f7 commit 3ffb47d
Showing 17 changed files with 3,038 additions and 5 deletions.
12 changes: 11 additions & 1 deletion README.md
@@ -22,7 +22,7 @@ machine learning interatomic potentials aiida plugin
- NVE
- NVT (Langevin(Eijnden/Ciccotti flavour) and Nosé-Hoover (Melchionna flavour))
- NPT (Nosé-Hoover (Melchionna flavour))
- [ ] Training ML potentials (MACE only planned)
- [x] Training ML potentials (MACE only planned)
- [ ] Fine-tuning MLIPs (MACE only planned)

The code relies heavily on [janus-core](https://github.com/stfc/janus-core), which handles mlip calculations using ASE.
@@ -45,6 +45,7 @@ Registered entry points for aiida.calculations:
* janus.opt
* janus.sp
* janus.md
* janus.train
```
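If the plugin is installed alongside AiiDA, the new training entry point can be resolved directly. A minimal check, not part of this commit, might look like:
```python
from aiida.plugins import CalculationFactory

# Resolve the entry point registered above; this fails if aiida-mlip
# is not installed in the current environment.
TrainCalculation = CalculationFactory("janus.train")
print(TrainCalculation)
```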


@@ -60,6 +61,11 @@ verdi run submit_md.py "janus@localhost" --struct "path/to/structure" --model "p

verdi process list -a # check record of calculation
```
Models can be trained using the `Train` calcjob. The required input is a config file containing the paths to the training, test and validation xyz files, together with any other optional parameters. Running
```shell
verdi run submit_train.py
```
trains a model using the example config file and xyz files provided in the tests folder.
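A rough, hand-written equivalent of such a submit script is sketched below (not part of this commit). It assumes a configured AiiDA profile, a `janus@localhost` code as in the earlier examples, and that `JanusConfigfile` accepts a file path in the same way as `SinglefileData`; the config path is a placeholder.
```python
from aiida import load_profile
from aiida.engine import run_get_node
from aiida.orm import load_code
from aiida.plugins import CalculationFactory

from aiida_mlip.data.config import JanusConfigfile

load_profile()  # assumes a configured AiiDA profile

TrainCalc = CalculationFactory("janus.train")

builder = TrainCalc.get_builder()
builder.code = load_code("janus@localhost")  # the installed janus code
# Placeholder path; the config must define "name", "train_file",
# "valid_file" and "test_file" (see validate_inputs in train.py below)
builder.mlip_config = JanusConfigfile("/path/to/mlip_train.yml")

result, node = run_get_node(builder)
print(result["results_dict"].get_dict())
```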

## Development

@@ -97,6 +103,7 @@ See the [developer guide](https://stfc.github.io/aiida-mlip/developer_guide/inde
* [`sp_parser.py`](aiida_mlip/parsers/sp_parser.py): `Parser` for `Singlepoint` calculation.
* [`opt_parser.py`](aiida_mlip/parsers/opt_parser.py): `Parser` for `Geomopt` calculation.
* [`md_parser.py`](aiida_mlip/parsers/md_parser.py): `Parser` for `MD` calculation.
* [`train_parser.py`](aiida_mlip/parsers/train_parser.py): `Parser` for `Train` calculation.
* [`helpers/`](aiida_mlip/helpers/): `Helpers` to run calculations.
* [`docs/`](docs/source/): Code documentation
* [`apidoc/`](docs/source/apidoc/): API documentation
@@ -108,14 +115,17 @@ See the [developer guide](https://stfc.github.io/aiida-mlip/developer_guide/inde
* [`submit_singlepoint.py`](examples/calculations/submit_singlepoint.py): Script for submitting a singlepoint calculation
* [`submit_geomopt.py`](examples/calculations/submit_geomopt.py): Script for submitting a geometry optimisation calculation
* [`submit_md.py`](examples/calculations/submit_md.py): Script for submitting a molecular dynamics calculation
* [`submit_train.py`](examples/calculations/submit_train.py): Script for submitting a train calculation.
* [`tests/`](tests/): Basic regression tests using the [pytest](https://docs.pytest.org/en/latest/) framework (submitting a calculation, ...). Install `pip install -e .[testing]` and run `pytest`.
* [`conftest.py`](tests/conftest.py): Configuration of fixtures for [pytest](https://docs.pytest.org/en/latest/)
* [`calculations/`](tests/calculations): Calculations
* [`test_singlepoint.py`](tests/calculations/test_singlepoint.py): Test `SinglePoint` calculation
* [`test_geomopt.py`](tests/calculations/test_geomopt.py): Test `Geomopt` calculation
* [`test_md.py`](tests/calculations/test_md.py): Test `MD` calculation
* [`test_train.py`](tests/calculations/test_train.py): Test `Train` calculation
* [`data/`](tests/data): `ModelData`
* [`test_model.py`](tests/data/test_model.py): Test `ModelData` type
* [`test_config.py`](tests/data/test_config.py): Test `JanusConfigfile` type
* [`.gitignore`](.gitignore): Telling git which files to ignore
* [`.pre-commit-config.yaml`](.pre-commit-config.yaml): Configuration of [pre-commit hooks](https://pre-commit.com/) that sanitize coding style and check for syntax errors. Enable via `pip install -e .[pre-commit] && pre-commit install`
* [`LICENSE`](LICENSE): License for the plugin
190 changes: 190 additions & 0 deletions aiida_mlip/calculations/train.py
@@ -0,0 +1,190 @@
"""Class for training machine learning models."""

from pathlib import Path

from aiida.common import InputValidationError, datastructures
import aiida.common.folders
from aiida.engine import CalcJob, CalcJobProcessSpec
import aiida.engine.processes
from aiida.orm import Dict, FolderData, SinglefileData

from aiida_mlip.data.config import JanusConfigfile
from aiida_mlip.data.model import ModelData


def validate_inputs(
inputs: dict, port_namespace: aiida.engine.processes.ports.PortNamespace
):
"""
Check if the inputs are valid.
Parameters
----------
inputs : dict
The inputs dictionary.
port_namespace : `aiida.engine.processes.ports.PortNamespace`
An instance of aiida's `PortNameSpace`.
Raises
------
InputValidationError
Error message if validation fails, None otherwise.
"""
if "mlip_config" in port_namespace:
# Check if a config file is given
if "mlip_config" not in inputs:
raise InputValidationError("No config file given")
config_file = inputs["mlip_config"]
# Check if 'name' keyword is given
if "name" not in config_file:
raise InputValidationError("key 'name' must be defined in the config file")
# Check if the xyz files paths are given
required_keys = ("train_file", "valid_file", "test_file")
for key in required_keys:
if key not in config_file:
raise InputValidationError(f"Mandatory key {key} not in config file")
# Check if the keys actually correspond to a path
if not ((Path(config_file.as_dictionary[key])).resolve()).exists():
raise InputValidationError(f"Path given for {key} does not exist")


class Train(CalcJob): # numpydoc ignore=PR01
"""
Calcjob implementation to train mlips.
Attributes
----------
DEFAULT_OUTPUT_FILE : str
Default stdout file name.
Methods
-------
define(spec: CalcJobProcessSpec) -> None:
Define the process specification, its inputs, outputs and exit codes.
validate_inputs(value: dict, port_namespace: PortNamespace) -> Optional[str]:
Check if the inputs are valid.
prepare_for_submission(folder: Folder) -> CalcInfo:
Create the input files for the `CalcJob`.
"""

DEFAULT_OUTPUT_FILE = "aiida-stdout.txt"

@classmethod
def define(cls, spec: CalcJobProcessSpec) -> None:
"""
Define the process specification, its inputs, outputs and exit codes.
Parameters
----------
spec : `aiida.engine.CalcJobProcessSpec`
The calculation job process spec to define.
"""
super().define(spec)

# Define inputs
spec.input(
"mlip_config",
valid_type=JanusConfigfile,
required=True,
help="Config file with parameters for training",
)
spec.input(
"metadata.options.output_filename",
valid_type=str,
default=cls.DEFAULT_OUTPUT_FILE,
)

spec.input(
"metadata.options.scheduler_stdout",
valid_type=str,
default="_scheduler-stdout.txt",
help="Filename to which the content of stdout of the scheduler is written.",
)
spec.inputs["metadata"]["options"]["parser_name"].default = "janus.train_parser"
spec.inputs.validator = validate_inputs
spec.output("model", valid_type=ModelData)
spec.output("compiled_model", valid_type=SinglefileData)
spec.output(
"results_dict",
valid_type=Dict,
help="The `results_dict` output node of the training.",
)
spec.output("logs", valid_type=FolderData)
spec.output("checkpoints", valid_type=FolderData)
spec.default_output_node = "results_dict"
# Exit codes
spec.exit_code(
305,
"ERROR_MISSING_OUTPUT_FILES",
message="Some output files missing or cannot be read",
)

# pylint: disable=too-many-locals
def prepare_for_submission(
self, folder: aiida.common.folders.Folder
) -> datastructures.CalcInfo:
"""
Create the input files for the `Calcjob`.
Parameters
----------
folder : aiida.common.folders.Folder
An `aiida.common.folders.Folder` to temporarily write files on disk.
Returns
-------
aiida.common.datastructures.CalcInfo
An instance of `aiida.common.datastructures.CalcInfo`.
"""
# The config file needs to be copied in the working folder
# Read content
mlip_dict = self.inputs.mlip_config.as_dictionary
config_parse = self.inputs.mlip_config.get_content()

# Extract paths from the config
for file in ("train_file", "test_file", "valid_file"):
abs_path = Path(mlip_dict[file]).resolve()

# Update the config file with absolute paths
config_parse = config_parse.replace(mlip_dict[file], str(abs_path))
# Copy config file content inside the folder where the calculation is run
config_copy = "mlip_train.yml"
with folder.open(config_copy, "w", encoding="utf-8") as configfile:
configfile.write(config_parse)

codeinfo = datastructures.CodeInfo()

# Initialize cmdline_params with train command
codeinfo.cmdline_params = ["train"]
# Create the rest of the command line
cmd_line = {}
cmd_line["mlip-config"] = config_copy
# Add cmd line params to codeinfo
for flag, value in cmd_line.items():
codeinfo.cmdline_params += [f"--{flag}", str(value)]

# Node where the code is saved
codeinfo.code_uuid = self.inputs.code.uuid
# Save name of output as you need it for running the code
codeinfo.stdout_name = self.metadata.options.output_filename

calcinfo = datastructures.CalcInfo()
calcinfo.codes_info = [codeinfo]
# Save the info about the node where the calc is stored
calcinfo.uuid = str(self.uuid)
# Retrieve output files
model_dir = Path(mlip_dict.get("model_dir", "."))
model_output = model_dir / f"{mlip_dict['name']}.model"
compiled_model_output = model_dir / f"{mlip_dict['name']}_compiled.model"
calcinfo.retrieve_list = [
self.metadata.options.output_filename,
self.uuid,
mlip_dict.get("log_dir", "logs"),
mlip_dict.get("result_dir", "results"),
mlip_dict.get("checkpoint_dir", "checkpoints"),
str(model_output),
str(compiled_model_output),
]

return calcinfo
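For context (not part of the diff): the config consumed by `prepare_for_submission` is a plain YAML file, and the configured janus code is ultimately invoked as `train --mlip-config mlip_train.yml` in the working directory. A minimal config that would satisfy `validate_inputs`, written with PyYAML and using placeholder paths, might look like:
```python
import yaml  # PyYAML, assumed available

# Keys mirror what validate_inputs and prepare_for_submission read above;
# all paths are placeholders and must exist on disk before submission.
config = {
    "name": "test",  # used to locate <name>.model and <name>_compiled.model
    "train_file": "/path/to/train.xyz",
    "valid_file": "/path/to/valid.xyz",
    "test_file": "/path/to/test.xyz",
    # Optional directories; defaults shown match the retrieve_list fallbacks
    "model_dir": ".",
    "log_dir": "logs",
    "result_dir": "results",
    "checkpoint_dir": "checkpoints",
}

with open("mlip_train.yml", "w", encoding="utf-8") as handle:
    yaml.safe_dump(config, handle)
```
The three xyz paths are rewritten to absolute paths before the copy named `mlip_train.yml` is placed in the working folder, so relative paths in the original config are resolved on the machine where the inputs are created.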
6 changes: 3 additions & 3 deletions aiida_mlip/parsers/base_parser.py
@@ -20,7 +20,7 @@ class BaseParser(Parser):
Methods
-------
__init__(node: aiida.orm.nodes.process.process.ProcessNode)
Initialize the SPParser instance.
Initialize the BaseParser instance.
parse(**kwargs: Any) -> int:
Parse outputs, store results in the database.
@@ -33,12 +33,12 @@ class BaseParser(Parser):
Raises
------
exceptions.ParsingError
If the ProcessNode being passed was not produced by a singlePointCalculation.
If the ProcessNode being passed was not produced by a `Base` Calcjob.
"""

def __init__(self, node: ProcessNode):
"""
Check that the ProcessNode being passed was produced by a `Singlepoint`.
Check that the ProcessNode being passed was produced by a `Base` Calcjob.
Parameters
----------