Commit
change parse structure because of too many variables
federicazanca committed May 17, 2024
1 parent 1820ee1 commit c2d531f
Showing 4 changed files with 120 additions and 54 deletions.
34 changes: 12 additions & 22 deletions aiida_mlip/calculations/train.py
@@ -55,10 +55,6 @@ class Train(CalcJob): # numpydoc ignore=PR01
----------
DEFAULT_OUTPUT_FILE : str
Default stdout file name.
DEFAULT_INPUT_FILE : str
Default input file name.
LOG_FILE : str
Default log file name.
Methods
-------
@@ -71,8 +67,6 @@ class Train(CalcJob): # numpydoc ignore=PR01
"""

DEFAULT_OUTPUT_FILE = "aiida-stdout.txt"
DEFAULT_INPUT_FILE = "aiida.xyz"
LOG_FILE = "aiida.log"

@classmethod
def define(cls, spec: CalcJobProcessSpec) -> None:
@@ -90,19 +84,15 @@ def define(cls, spec: CalcJobProcessSpec) -> None:
spec.input(
"mlip_config",
valid_type=JanusConfigfile,
required=False,
help="Mlip architecture to use for calculation, defaults to mace",
required=True,
help="Config file with parameters for training",
)
spec.input(
"metadata.options.output_filename",
valid_type=str,
default=cls.DEFAULT_OUTPUT_FILE,
)
spec.input(
"metadata.options.input_filename",
valid_type=str,
default=cls.DEFAULT_INPUT_FILE,
)

spec.input(
"metadata.options.scheduler_stdout",
valid_type=str,
@@ -145,10 +135,8 @@ def prepare_for_submission(
aiida.common.datastructures.CalcInfo
An instance of `aiida.common.datastructures.CalcInfo`.
"""
cmd_line = {}

cmd_line["mlip-config"] = "mlip_train.yml"

# The config file needs to be copied in the working folder
# Read content
mlip_dict = self.inputs.mlip_config.as_dictionary
config_parse = self.inputs.mlip_config.get_content()
# Extract paths from the config
@@ -172,15 +160,14 @@
with folder.open("mlip_train.yml", "w", encoding="utf-8") as configfile:
configfile.write(config_parse)

model_dir = Path(mlip_dict.get("model_dir", "."))
model_output = model_dir / f"{mlip_dict['name']}.model"
compiled_model_output = model_dir / f"{mlip_dict['name']}_compiled.model"

codeinfo = datastructures.CodeInfo()

# Initialize cmdline_params with train command
codeinfo.cmdline_params = ["train"]

# Create the rest of the command line
cmd_line = {}
cmd_line["mlip-config"] = "mlip_train.yml"
# Add cmd line params to codeinfo
for flag, value in cmd_line.items():
codeinfo.cmdline_params += [f"--{flag}", str(value)]

@@ -194,6 +181,9 @@
# Save the info about the node where the calc is stored
calcinfo.uuid = str(self.uuid)
# Retrieve output files
model_dir = Path(mlip_dict.get("model_dir", "."))
model_output = model_dir / f"{mlip_dict['name']}.model"
compiled_model_output = model_dir / f"{mlip_dict['name']}_compiled.model"
calcinfo.retrieve_list = [
self.metadata.options.output_filename,
self.uuid,
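For illustration, a minimal standalone sketch of how the flag dictionary in the hunk above is turned into command-line arguments for the janus train subcommand. The values mirror the diff; the snippet itself is not part of the changed file.

cmd_line = {"mlip-config": "mlip_train.yml"}

# Start from the `train` subcommand, then append each flag/value pair,
# as prepare_for_submission does above.
cmdline_params = ["train"]
for flag, value in cmd_line.items():
    cmdline_params += [f"--{flag}", str(value)]

print(cmdline_params)  # ['train', '--mlip-config', 'mlip_train.yml']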
6 changes: 3 additions & 3 deletions aiida_mlip/parsers/base_parser.py
@@ -20,7 +20,7 @@ class BaseParser(Parser):
Methods
-------
__init__(node: aiida.orm.nodes.process.process.ProcessNode)
Initialize the SPParser instance.
Initialize the BaseParser instance.
parse(**kwargs: Any) -> int:
Parse outputs, store results in the database.
@@ -33,12 +33,12 @@ class BaseParser(Parser):
Raises
------
exceptions.ParsingError
If the ProcessNode being passed was not produced by a singlePointCalculation.
If the ProcessNode being passed was not produced by a `Base` Calcjob.
"""

def __init__(self, node: ProcessNode):
"""
Check that the ProcessNode being passed was produced by a `Singlepoint`.
Check that the ProcessNode being passed was produced by a `Base` Calcjob.
Parameters
----------
132 changes: 104 additions & 28 deletions aiida_mlip/parsers/train_parser.py
@@ -4,6 +4,7 @@

import ast
from pathlib import Path
from typing import Any

from aiida.engine import ExitCode
from aiida.orm import Dict, FolderData
@@ -25,11 +26,26 @@ class TrainParser(Parser):
Methods
-------
__init__(node: aiida.orm.nodes.process.process.ProcessNode)
Initialize the SPParser instance.
Initialize the TrainParser instance.
parse(**kwargs: Any) -> int:
Parse outputs, store results in the database.
_get_remote_dirs(mlip_dict: [str, Any]) -> [str, Path]:
Get the remote directories based on mlip config file.
_validate_retrieved_files(output_filename: str, model_name: str) -> bool:
Validate that the expected files have been retrieved.
_save_models(model_output: Path, compiled_model_output: Path) -> None:
Save model and compiled model as outputs.
_parse_results(result_name: Path) -> None:
Parse the results file and store the results dictionary.
_save_folders(remote_dirs: [str, Path]) -> None:
Save log and checkpoint folders as outputs.
Returns
-------
int
@@ -38,12 +54,12 @@ class TrainParser(Parser):
Raises
------
exceptions.ParsingError
If the ProcessNode being passed was not produced by a singlePointCalculation.
If the ProcessNode being passed was not produced by a `Train` Calcjob.
"""

def __init__(self, node: ProcessNode):
"""
Check that the ProcessNode being passed was produced by a `Singlepoint`.
Initialize the TrainParser instance.
Parameters
----------
@@ -52,11 +68,9 @@ def __init__(self, node: ProcessNode):
"""
super().__init__(node)

# disable for now
# pylint: disable=too-many-locals
def parse(self, **kwargs) -> int:
def parse(self, **kwargs: Any) -> int:
"""
Parse outputs, store results in the database.
Parse outputs and store results in the database.
Parameters
----------
@@ -68,11 +82,42 @@ def parse(self, **kwargs: Any) -> int:
int
An exit code.
"""
remote_dir = Path(self.node.get_remote_workdir())
print(self.node.inputs.mlip_config)
mlip_dict = self.node.inputs.mlip_config.as_dictionary
remote_dirs = {
typ: remote_dir / mlip_dict.get(f"{typ}_dir", default)
output_filename = self.node.get_option("output_filename")
remote_dirs = self._get_remote_dirs(mlip_dict)

model_output = remote_dirs["model"] / f"{mlip_dict['name']}.model"
compiled_model_output = (
remote_dirs["model"] / f"{mlip_dict['name']}_compiled.model"
)
result_name = remote_dirs["results"] / f"{mlip_dict['name']}_run-2024_train.txt"

if not self._validate_retrieved_files(output_filename, mlip_dict["name"]):
return self.exit_codes.ERROR_MISSING_OUTPUT_FILES

self._save_models(model_output, compiled_model_output)
self._parse_results(result_name)
self._save_folders(remote_dirs)

return ExitCode(0)

def _get_remote_dirs(self, mlip_dict: dict) -> dict:
"""
Get the remote directories based on mlip config file.
Parameters
----------
mlip_dict : dict
Dictionary containing mlip config file.
Returns
-------
dict
Dictionary of remote directories.
"""
rem_dir = Path(self.node.get_remote_workdir())
return {
typ: rem_dir / mlip_dict.get(f"{typ}_dir", default)
for typ, default in (
("log", "logs"),
("checkpoint", "checkpoints"),
@@ -81,34 +126,61 @@ def parse(self, **kwargs) -> int:
)
}

output_filename = self.node.get_option("output_filename")
model_output = remote_dirs["model"] / f"{mlip_dict['name']}.model"
compiled_model_output = (
remote_dirs["model"] / f"{mlip_dict['name']}_compiled.model"
)
result_name = remote_dirs["results"] / f"{mlip_dict['name']}_run-2024_train.txt"
def _validate_retrieved_files(self, output_filename: str, model_name: str) -> bool:
"""
Validate that the expected files have been retrieved.
# Check that folder content is as expected
Parameters
----------
output_filename : str
The expected output filename.
model_name : str
The name of the model as found in the config file key `name`.
Returns
-------
bool
True if the expected files are retrieved, False otherwise.
"""
files_retrieved = self.retrieved.list_object_names()
files_expected = {output_filename, f"{model_name}.model"}

files_expected = {output_filename}
if not files_expected.issubset(files_retrieved):
self.logger.error(
f"Found files '{files_retrieved}', expected to find '{files_expected}'"
)
return self.exit_codes.ERROR_MISSING_OUTPUT_FILES
return False
return True

# Save models as outputs
# Need to change the architecture
def _save_models(self, model_output: Path, compiled_model_output: Path) -> None:
"""
Save model and compiled model as outputs.
Parameters
----------
model_output : Path
Path to the model output file.
compiled_model_output : Path
Path to the compiled model output file.
"""
architecture = "mace_mp"
model = ModelData.local_file(model_output, architecture=architecture)
compiled_model = ModelData.local_file(
compiled_model_output, architecture=architecture
)

self.out("model", model)
self.out("compiled_model", compiled_model)

# In the result file find the last dictionary
def _parse_results(self, result_name: Path) -> None:
"""
Parse the results file and store the results dictionary.
Parameters
----------
result_name : Path
Path to the result file.
"""
with open(result_name, encoding="utf-8") as file:
last_dict_str = None
for line in file:
@@ -117,19 +189,23 @@
except (SyntaxError, ValueError):
continue

# Convert the last dictionary string to a Dict
if last_dict_str is not None:
results_node = Dict(last_dict_str)
self.out("results_dict", results_node)
else:
raise ValueError("No valid dictionary in the file")

# Save log folder as output
def _save_folders(self, remote_dirs: dict) -> None:
"""
Save log and checkpoint folders as outputs.
Parameters
----------
remote_dirs : dict
Dictionary of remote folders.
"""
log_node = FolderData(tree=remote_dirs["log"])
self.out("logs", log_node)

# Save checkpoint folder as output
checkpoint_node = FolderData(tree=remote_dirs["checkpoint"])
self.out("checkpoints", checkpoint_node)

return ExitCode(0)
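For reference, a self-contained sketch of the "last dictionary" strategy used by _parse_results. The collapsed hunk above is assumed to call ast.literal_eval on each stripped line; only lines that parse as a dict are kept.

import ast

def last_dict_in_file(path: str) -> dict:
    """Return the last line of `path` that parses as a Python dict literal."""
    last_dict = None
    with open(path, encoding="utf-8") as file:
        for line in file:
            try:
                candidate = ast.literal_eval(line.strip())
                if isinstance(candidate, dict):
                    last_dict = candidate
            except (SyntaxError, ValueError):
                continue
    if last_dict is None:
        raise ValueError("No valid dictionary in the file")
    return last_dict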
2 changes: 1 addition & 1 deletion examples/calculations/submit_train.py
@@ -12,7 +12,7 @@
metadata = {"options": {"resources": {"num_machines": 1}}}
code = load_code("janus@localhost")

# All the other paramenters we want them from the config file
# All the other parameters we want them from the config file
# We want to pass it as a AiiDA data type for the provenance
mlip_config = JanusConfigfile(
(
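A possible continuation of this example script, sketched under the assumption that the Train CalcJob class is imported directly from aiida_mlip.calculations.train; the plugin may instead expose it via an entry point, and the rest of the file is truncated in this diff.

from aiida.engine import submit
from aiida_mlip.calculations.train import Train  # direct import; an entry point factory may be preferred

# Submit the training CalcJob with the code, metadata, and mlip_config
# defined earlier in the script.
submit(Train, code=code, mlip_config=mlip_config, metadata=metadata)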
