Add fine-tuning option (#142)
* Update train CLI help

* Add fine-tuning option for CLI

* Add tests for MLIP training

* Add skip for training if MACE import fails

* Tidy training tests

* Update README

---------

Co-authored-by: ElliottKasoar <ElliottKasoar@users.noreply.github.com>
ElliottKasoar and ElliottKasoar committed May 11, 2024
1 parent c48cdb1 commit adf3dbc
Showing 13 changed files with 3,790 additions and 4 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -21,3 +21,7 @@ dist/
pip-wheel-metadata/
__pycache__/
.tox/
checkpoints/
results/
logs/
*.model
24 changes: 22 additions & 2 deletions README.md
@@ -26,9 +26,9 @@ Tools for machine learnt interatomic potentials
- [ ] Nudge Elastic Band
- [ ] Phonons
- vibroscopy
- [ ] Training ML potentials
- [x] Training ML potentials
- MACE
- [ ] Fine tuning MLIPs
- [x] Fine tuning MLIPs
- MACE
- [ ] Rare events simulations
- PLUMED
@@ -198,6 +198,26 @@ This will run a singlepoint energy calculation on `KCl.cif` using the [MACE-MP](
> `properties` must be passed as a YAML list, as above, not as a string.

### Training and fine-tuning MACE models

> [!NOTE]
> This currently requires use of the [develop branch of MACE](https://github.com/ACEsuit/mace/tree/develop).
> This can be installed by running `poetry add git+https://github.com/ACEsuit/mace.git#develop`, followed by `poetry install`.

MACE models can be trained by passing a configuration file to the [MACE CLI](https://github.com/ACEsuit/mace/blob/main/mace/cli/run_train.py):

```shell
janus train --mlip-config /path/to/training/config.yml
```

This will create `logs`, `checkpoints` and `results` folders, and will save both the trained model and a compiled version of it.
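
The configuration file is passed through to MACE's `run_train`, so it accepts the same keywords as that script. A minimal sketch, with keys drawn from the test configurations added in this commit (paths and hyperparameters are placeholders, not a tested recipe):

```yaml
name: my-mace-model
model: MACE
train_file: "/path/to/train.xyz"
valid_fraction: 0.1
energy_key: dft_energy
forces_key: dft_forces
batch_size: 4
max_num_epochs: 100
device: cpu
seed: 2024
```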

Foundational models can also be fine-tuned by including the `foundation_model` option in your configuration file and using the `--fine-tune` option:

```shell
janus train --mlip-config /path/to/fine/tuning/config.yml --fine-tune
```
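
In the configuration file, `foundation_model` may be one of the named MACE-MP models (`small`, `medium`, `large`, `small_off`, `medium_off`, `large_off`) or a path to a model file, as validated by `janus train` (see `janus_core/cli/train.py` in this commit). A trimmed sketch based on the fine-tuning test configuration added here (paths and hyperparameters are placeholders):

```yaml
name: my-finetuned-model
foundation_model: "/path/to/mace_mp_small.model"  # or a named model, e.g. small
model: ScaleShiftMACE
loss: universal
train_file: "/path/to/fine_tune_train.xyz"
lr: 0.005
max_num_epochs: 10
device: cpu
```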

## License

[BSD 3-Clause License](LICENSE)
34 changes: 32 additions & 2 deletions janus_core/cli/train.py
@@ -4,15 +4,21 @@
from typing import Annotated

from typer import Option, Typer
import yaml

from janus_core.helpers.train import train as run_train

app = Typer()


@app.command(help="Perform single point calculations and save to file.")
@app.command(help="Running training for an MLIP.")
def train(
    mlip_config: Annotated[Path, Option(help="Configuration file to pass to MLIP CLI.")]
    mlip_config: Annotated[
        Path, Option(help="Configuration file to pass to MLIP CLI.")
    ],
    fine_tune: Annotated[
        bool, Option(help="Whether to fine-tune a foundational model.")
    ] = False,
):
    """
    Run training for MLIP by passing a configuration file to the MLIP's CLI.
@@ -21,5 +21,29 @@ def train(
    ----------
    mlip_config : Path
        Configuration file to pass to MLIP CLI.
    fine_tune : bool
        Whether to fine-tune a foundational model. Default is False.
    """
    # Read the configuration file to validate fine-tuning options
    with open(mlip_config, encoding="utf8") as config_file:
        config = yaml.safe_load(config_file)

    if fine_tune:
        # Fine-tuning requires a foundation model: either one of the named
        # MACE-MP models, or a path to an existing model file
        if "foundation_model" not in config:
            raise ValueError(
                "Please include `foundation_model` in your configuration file"
            )
        if (
            config["foundation_model"]
            not in ("small", "medium", "large", "small_off", "medium_off", "large_off")
            and not Path(config["foundation_model"]).exists()
        ):
            raise ValueError(
                """
                Invalid foundational model. Valid options are: 'small', 'medium',
                'large', 'small_off', 'medium_off', 'large_off', or a path to the model
                """
            )
    elif "foundation_model" in config:
        raise ValueError("Please include the `--fine-tune` option for fine-tuning")

    run_train(mlip_config)
34 changes: 34 additions & 0 deletions tests/data/mlip_fine_tune.yml
@@ -0,0 +1,34 @@
name: test-finetuned
foundation_model: "tests/models/mace_mp_small.model"
model: ScaleShiftMACE
loss: universal
train_file: "./tests/data/mlip_fine_tune_train.xyz"
valid_file: "./tests/data/mlip_fine_tune_train.xyz"
test_file: "./tests/data/mlip_test.xyz"
E0s: foundation
valid_fraction: 0.05
energy_weight: 1.0
forces_weight: 10.0
stress_weight: 100.0
stress_key: dft_stress
energy_key: dft_energy
forces_key: dft_forces
compute_stress: True
compute_forces: True
clip_grad: 100
error_table: PerAtomRMSE
lr: 0.005
scaling: rms_forces_scaling
batch_size: 4
max_num_epochs: 1
ema: True
ema_decay: 0.995
amsgrad: True
default_dtype: float64
device: cpu
seed: 2024
keep_isolated_atoms: True
keep_checkpoints: False
save_cpu: True
weight_decay: 1e-8
eval_interval: 2
2 changes: 2 additions & 0 deletions tests/data/mlip_fine_tune_invalid_foundation.yml
@@ -0,0 +1,2 @@
model: MACE
foundation_model: test
1 change: 1 addition & 0 deletions tests/data/mlip_fine_tune_no_foundation.yml
@@ -0,0 +1 @@
model: MACE
1,160 changes: 1,160 additions & 0 deletions tests/data/mlip_fine_tune_train.xyz

Large diffs are not rendered by default.

116 changes: 116 additions & 0 deletions tests/data/mlip_test.xyz
@@ -0,0 +1,116 @@
114
Lattice="14.656053852204064 0.0 0.0 7.152716897291043 12.83340508269681 0.0 7.3175087671203345 4.21254028622021 12.109471678044383" Properties=species:S:1:pos:R:3:initial_magmoms:R:1:dft_forces:R:3 dft_energy=-842.19986612 dft_stress="-0.002807895578799348 0.001581940291161188 0.0004783320390840146 -0.0013659013941342366 0.0012554940035235348 -0.0011820129849694034" pbc="T T T"
Zr 22.48182938 14.27491151 4.60630202 0.00000000 -0.62996178 0.95368699 0.82369577
Zr 15.33378953 5.70638380 7.53308287 0.00000000 -0.45101320 -1.06264988 0.16373682
Zr 17.02435992 4.55275102 4.58910306 0.00000000 0.65052108 0.84507175 -0.39204271
Zr 20.58667988 15.41317450 7.58829048 0.00000000 1.75845798 -0.11079006 -0.85052821
Zr 24.22636799 15.38223202 7.46185572 0.00000000 0.10864548 0.66981166 1.20893334
Zr 13.56890565 4.59679979 4.57274878 0.00000000 -0.27590197 -0.34587007 0.71125185
H 20.58291280 10.53438721 4.42768471 0.00000000 0.07068061 0.04305158 -0.15985085
H 13.67565519 6.84422131 11.64659825 0.00000000 0.20732998 0.03991176 -0.31304837
H 9.62830143 12.80497036 1.00715589 0.00000000 -0.04385990 0.16508199 -0.16116548
H 16.86796977 9.75523952 7.73425960 0.00000000 1.11522761 -0.27027501 -0.64116013
H 17.28048397 6.95652516 0.85718791 0.00000000 0.78496398 0.08982319 0.31862062
H 9.49504772 4.44745811 4.03855331 0.00000000 -0.31441595 0.49221048 -0.23562432
H 6.58059986 3.33885293 7.77292005 0.00000000 0.14836379 -0.94894842 -0.10347122
H 20.35085942 13.11546456 11.27256415 0.00000000 0.44630031 -0.37808285 0.00280150
H 17.31277155 7.39195042 10.93539770 0.00000000 0.38048834 0.23992616 0.50095745
H 11.84505596 8.79844934 3.92648162 0.00000000 0.09617687 -0.53107740 -0.60378195
H 11.52615792 11.44652110 7.92371163 0.00000000 0.04517207 0.38862373 0.62618192
H 20.78218151 13.07935111 0.92513510 0.00000000 -0.49508261 -0.56466228 -0.09680000
H 24.61711539 12.69539470 10.92214020 0.00000000 0.87000599 -0.05404968 -0.35535968
H 9.09594941 2.80333202 7.95033331 0.00000000 0.10104232 0.01204551 0.33397358
H 6.95092026 4.44476419 3.65551607 -0.00000000 0.18310559 -0.20465031 0.81008866
H 13.07555238 7.34697631 1.08645537 0.00000000 0.23070166 0.31902772 0.26045858
H 19.04198063 8.50951594 4.10283149 0.00000000 0.22792608 -0.27138971 0.01541628
H 18.29329127 16.16844071 11.24261935 -0.00000000 0.37876073 0.16521638 -0.02576211
H 4.78537516 3.56090119 1.12382828 0.00000000 0.36870756 1.08767380 0.60356398
H 18.60691102 11.72729791 7.98856647 0.00000000 -0.55819576 0.07444669 0.03123210
H 11.18445984 3.68366102 1.33577932 0.00000000 -0.49266331 -0.42868630 -0.73637819
H 12.94622738 9.34519669 8.00607741 0.00000000 -0.16835674 0.57625055 -0.28879830
H 10.52404189 10.96839470 3.88410417 0.00000000 0.11172472 -0.27072103 -0.82229649
H 26.41067535 16.36144183 11.08394202 0.00000000 0.20410786 0.29039336 0.13413951
H 22.58187133 13.10363845 7.32993452 0.00000000 0.02512952 0.07186285 -0.07401471
H 12.55413114 5.28549259 7.28487964 0.00000000 -0.02235129 -0.04048921 -0.23684377
H 15.42405498 3.58988145 2.53801188 0.00000000 -0.34947920 0.02384834 0.09021419
H 18.25643048 4.93275546 7.26670564 0.00000000 -0.11203298 0.29129704 -0.20001095
C 19.78778191 10.30495195 5.12809145 0.00000000 -0.54800673 0.65453955 -0.51609416
C 12.80980821 6.26074118 11.93756406 0.00000000 1.11834913 0.55927271 -1.47775984
C 3.28397790 0.57876490 0.58806051 0.00000000 0.05103387 0.19184345 1.52656116
C 17.74488342 9.90117006 7.04283584 0.00000000 0.56656521 -0.05882561 -0.90558873
C 16.40824325 7.46225275 0.49532087 -0.00000000 -0.74538895 -0.91398771 0.65936042
C 8.77120357 4.20140918 4.81130682 -0.00000000 0.83140408 -0.87598045 0.14246587
C 7.23680915 3.28095736 6.89316802 0.00000000 0.11078116 0.00319742 1.51923692
C 21.29830492 12.65237168 11.60248336 0.00000000 0.58187288 0.02507193 -0.16948253
C 18.06206746 6.81921496 11.50113773 0.00000000 0.40747767 -2.07918221 0.14262449
C 11.81127523 9.41107555 4.80673089 0.00000000 -0.10051697 -1.53060007 -1.15003097
C 11.62029585 10.85447232 7.03230801 0.00000000 1.28752895 0.39119803 1.73279861
C 12.71613914 0.70288188 0.52736572 0.00000000 -0.68916790 0.82900637 0.33976653
C 23.73599337 12.40140928 11.44193717 0.00000000 -1.89121688 0.54697591 -0.41856290
C 8.62215975 3.16729236 7.05294541 -0.00000000 -0.33611175 1.00796153 -1.56257503
C 7.40944796 4.13840125 4.62561089 0.00000000 -1.05768620 -0.66358101 -0.04099532
C 14.03605546 7.66493432 0.65574152 0.00000000 -1.96521670 -0.28704619 -0.10888283
C 18.95870959 9.18935834 4.94012070 0.00000000 1.15625723 1.64537886 0.66139756
C 11.42913395 4.27385075 11.72743310 0.00000000 0.50008025 1.00227303 -1.06547335
C 4.64442494 2.59100261 0.73640566 0.00000000 -0.90617630 -1.04531322 0.82790092
C 18.58994989 11.03056553 7.15400818 0.00000000 -0.95367309 -1.13339412 -0.27378905
C 11.27584065 2.79740097 0.66648308 0.00000000 2.19650576 -0.01781540 1.07154831
C 12.47877075 9.74580963 7.08892956 0.00000000 -1.31339283 -1.62844086 -1.02047273
C 11.08255426 10.56684887 4.70838950 0.00000000 -0.10076536 1.45964130 2.96811562
C 19.22716888 4.50423312 11.58317069 0.00000000 -0.55096104 2.18231802 -1.59354484
C 2.76317406 2.40290932 2.27771904 -0.00000000 -0.59775137 -1.58376622 1.27387516
C 17.40404588 7.57296994 5.94324446 -0.00000000 -1.56743116 1.23997613 1.16760305
C 20.22734071 12.51092336 6.15403650 -0.00000000 1.22957509 1.90035653 0.03562680
C 13.23276845 4.69549753 9.96908453 -0.00000000 0.50718825 2.08974222 0.61377288
C 5.10721376 3.61653945 5.70636886 0.00000000 2.73355068 0.65037236 -0.93035569
C 22.49474278 13.84651215 9.75433671 0.00000000 1.36969335 -0.12477390 0.80175651
C 15.21728223 6.34016578 2.42568485 -0.00000000 0.60210491 -0.63015004 -0.17150813
C 10.86509741 3.60814266 6.03571869 -0.00000000 -1.26808337 -0.09303192 -1.39703912
C 10.32059109 12.67417244 5.87289369 -0.00000000 -0.19562271 -1.61251767 0.60019594
C 13.01471446 2.21756005 2.44626754 -0.00000000 0.86092717 -0.42579082 -1.07737949
C 17.57499248 5.00231795 9.71452000 -0.00000000 1.10269277 0.95727595 1.80798168
C 13.09651755 7.51517580 6.02158091 -0.00000000 0.16152275 1.03866058 0.00051518
C 19.52961113 11.26506473 6.14689692 0.00000000 -0.47923424 -2.00926951 -0.00167911
C 12.52016611 5.09627002 11.25061211 0.00000000 -0.85821135 -2.14406407 0.00629697
C 3.56820853 1.83194621 1.20404202 0.00000000 -0.24947921 -0.42605623 -1.37841177
C 18.01743128 9.00143862 5.91933659 0.00000000 -1.89930848 -1.63034887 2.04649579
C 15.19452756 7.15622755 1.19432937 0.00000000 1.27993056 0.82959345 -1.10206065
C 9.38342828 3.65796068 5.96573626 0.00000000 -0.49527771 -0.34871106 0.87187589
C 6.62365389 3.59161306 5.69646438 0.00000000 0.08158964 1.40936005 -1.12924958
C 22.49822633 12.89119359 10.90410513 0.00000000 -0.01534993 0.25440759 0.98005323
C 18.38052004 5.54018958 10.94507371 0.00000000 -0.42962186 -2.37600741 0.12873569
C 12.44485007 8.87941212 5.94338497 0.00000000 0.55319558 1.55865145 1.55436660
C 10.96925659 11.34982105 5.93398978 0.00000000 0.87996439 -1.55717178 -2.73161243
C 12.32836771 1.89965457 1.17431195 0.00000000 -1.28221891 -0.14477468 -0.92883689
O 21.11711550 12.62179089 5.23136551 -0.00000000 0.33710932 0.64341980 0.00434688
O 14.05022859 5.61287803 9.51235314 0.00000000 -0.83755272 -1.62056392 -0.18869006
O 16.49586196 1.64212691 2.81325184 0.00000000 1.33843890 -0.08524281 -1.21436271
O 16.51483781 7.38700121 6.89142490 0.00000000 0.99256559 0.25144287 -0.72190444
O 16.35950251 6.04692579 2.90272957 -0.00000000 0.42654712 -0.18406863 0.28220103
O 11.37664037 2.85711218 6.91666764 -0.00000000 0.32622175 0.18088953 0.59787604
O 4.55849395 4.24912635 4.76601017 -0.00000000 -1.04603443 0.64592324 -0.98906917
O 21.39439287 14.20556624 9.26851925 -0.00000000 -1.02536927 -0.03426866 -0.07702566
O 17.79706063 3.81177830 9.36736657 -0.00000000 -0.07397140 -0.57323519 -0.87936489
O 12.83605234 6.72852854 5.06669023 0.00000000 0.17541133 -0.65136354 -0.15309324
O 10.57914227 13.43859821 6.85772410 0.00000000 -0.60809472 0.75544597 -0.13737370
O 12.86986208 3.39642398 2.81196643 -0.00000000 -1.01417472 2.24284815 1.56594584
O 23.65581139 14.27655500 9.41389153 0.00000000 0.00268412 0.04891899 -0.81907354
O 11.42472996 4.24906649 5.07945179 -0.00000000 1.02882190 -0.07815029 0.18245872
O 4.56702803 3.07814059 6.66707705 -0.00000000 -1.19995405 -1.71125358 2.55649186
O 14.11609555 5.89328931 2.87465801 -0.00000000 -0.82170982 0.25217877 0.18716108
O 17.73638295 6.70591416 5.10280868 0.00000000 0.49208963 -0.02062198 -0.46226448
O 12.87622968 3.59181029 9.44347702 -0.00000000 0.08008429 0.41248775 -0.15837691
O 3.11030822 3.53227393 2.72475726 0.00000000 -0.26757271 0.53939021 0.12620310
O 19.93588034 13.43377070 7.02852626 0.00000000 -0.07438593 -1.58871488 -0.77938934
O 13.84382844 1.39972231 2.97637902 -0.00000000 -0.37439899 -0.80008548 -0.03503645
O 13.86369820 7.28727219 7.04104684 -0.00000000 -0.51524344 0.22394783 -0.48768489
O 9.51842754 12.77225733 4.87638131 -0.00000000 -0.68242216 1.24384159 0.40002763
O 16.76119477 5.83819930 9.20759103 -0.00000000 -0.48896777 0.11154905 -0.30434738
O 15.26378930 3.61842800 3.49257841 -0.00000000 0.38698199 0.17880403 1.59589744
O 17.39344690 4.73233843 6.86898526 -0.00000000 -0.75289662 -0.06223975 -0.03985463
O 22.57211935 14.00424155 6.96757462 -0.00000000 0.04997077 0.26922446 -0.71939579
O 13.26041567 4.74033158 6.89718949 -0.00000000 0.16239629 0.28002000 0.44618120
O 15.25455144 3.61728206 7.77094086 -0.00000000 0.93027009 -0.35770884 0.29376541
O 13.83448042 2.67300898 5.54465998 -0.00000000 0.72276470 0.39229014 -1.20785068
O 15.34997262 5.24669185 5.48218276 -0.00000000 -0.82219018 0.10716146 -0.00022242
O 16.88908477 2.63088792 5.45947486 -0.00000000 -1.12156467 0.54434099 -1.53204722
