Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fine-tuning option #142

Merged
merged 6 commits into from
May 11, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ dist/
pip-wheel-metadata/
__pycache__/
.tox/
checkpoints/
results/
logs/
*.model
34 changes: 32 additions & 2 deletions janus_core/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,21 @@
from typing import Annotated

from typer import Option, Typer
import yaml

from janus_core.helpers.train import train as run_train

app = Typer()


@app.command(help="Perform single point calculations and save to file.")
@app.command(help="Running training for an MLIP.")
def train(
mlip_config: Annotated[Path, Option(help="Configuration file to pass to MLIP CLI.")]
mlip_config: Annotated[
Path, Option(help="Configuration file to pass to MLIP CLI.")
],
fine_tune: Annotated[
bool, Option(help="Whether to fine-tune a foundational model.")
] = False,
):
"""
Run training for MLIP by passing a configuration file to the MLIP's CLI.
Expand All @@ -21,5 +27,29 @@ def train(
----------
mlip_config : Path
Configuration file to pass to MLIP CLI.
fine_tune : bool
Whether to fine-tune a foundational model. Default is False.
"""
with open(mlip_config, encoding="utf8") as config_file:
config = yaml.safe_load(config_file)

if fine_tune:
if "foundation_model" not in config:
raise ValueError(
"Please include `foundation_model` in your configuration file"
)
if (
config["foundation_model"]
not in ("small", "medium", "large", "small_off", "medium_off", "large_off")
and not Path(config["foundation_model"]).exists()
):
raise ValueError(
"""
Invalid foundational model. Valid options are: 'small', 'medium',
'large', 'small_off', 'medium_off', 'large_off', or a path to the model
"""
)
elif "foundation_model" in config:
raise ValueError("Please include the `--fine-tune` option for fine-tuning")

run_train(mlip_config)
34 changes: 34 additions & 0 deletions tests/data/mlip_fine_tune.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
name: test-finetuned
foundation_model: "tests/models/mace_mp_small.model"
model: ScaleShiftMACE
loss: universal
train_file: "./tests/data/mlip_fine_tune_train.xyz"
valid_file: "./tests/data/mlip_fine_tune_train.xyz"
test_file: "./tests/data/mlip_test.xyz"
E0s: foundation
valid_fraction: 0.05
energy_weight: 1.0
forces_weight: 10.0
stress_weight: 100.0
stress_key: dft_stress
energy_key: dft_energy
forces_key: dft_forces
compute_stress: True
compute_forces: True
clip_grad: 100
error_table: PerAtomRMSE
lr: 0.005
scaling: rms_forces_scaling
batch_size: 4
max_num_epochs: 1
ema: True
ema_decay: 0.995
amsgrad: True
default_dtype: float64
alinelena marked this conversation as resolved.
Show resolved Hide resolved
device: cpu
seed: 2024
keep_isolated_atoms: True
keep_checkpoints: False
save_cpu: True
weight_decay: 1e-8
eval_interval: 2
2 changes: 2 additions & 0 deletions tests/data/mlip_fine_tune_invalid_foundation.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
model: MACE
foundation_model: test
1 change: 1 addition & 0 deletions tests/data/mlip_fine_tune_no_foundation.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
model: MACE
1,160 changes: 1,160 additions & 0 deletions tests/data/mlip_fine_tune_train.xyz

Large diffs are not rendered by default.

116 changes: 116 additions & 0 deletions tests/data/mlip_test.xyz
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
114
Lattice="14.656053852204064 0.0 0.0 7.152716897291043 12.83340508269681 0.0 7.3175087671203345 4.21254028622021 12.109471678044383" Properties=species:S:1:pos:R:3:initial_magmoms:R:1:dft_forces:R:3 dft_energy=-842.19986612 dft_stress="-0.002807895578799348 0.001581940291161188 0.0004783320390840146 -0.0013659013941342366 0.0012554940035235348 -0.0011820129849694034" pbc="T T T"
Zr 22.48182938 14.27491151 4.60630202 0.00000000 -0.62996178 0.95368699 0.82369577
Zr 15.33378953 5.70638380 7.53308287 0.00000000 -0.45101320 -1.06264988 0.16373682
Zr 17.02435992 4.55275102 4.58910306 0.00000000 0.65052108 0.84507175 -0.39204271
Zr 20.58667988 15.41317450 7.58829048 0.00000000 1.75845798 -0.11079006 -0.85052821
Zr 24.22636799 15.38223202 7.46185572 0.00000000 0.10864548 0.66981166 1.20893334
Zr 13.56890565 4.59679979 4.57274878 0.00000000 -0.27590197 -0.34587007 0.71125185
H 20.58291280 10.53438721 4.42768471 0.00000000 0.07068061 0.04305158 -0.15985085
H 13.67565519 6.84422131 11.64659825 0.00000000 0.20732998 0.03991176 -0.31304837
H 9.62830143 12.80497036 1.00715589 0.00000000 -0.04385990 0.16508199 -0.16116548
H 16.86796977 9.75523952 7.73425960 0.00000000 1.11522761 -0.27027501 -0.64116013
H 17.28048397 6.95652516 0.85718791 0.00000000 0.78496398 0.08982319 0.31862062
H 9.49504772 4.44745811 4.03855331 0.00000000 -0.31441595 0.49221048 -0.23562432
H 6.58059986 3.33885293 7.77292005 0.00000000 0.14836379 -0.94894842 -0.10347122
H 20.35085942 13.11546456 11.27256415 0.00000000 0.44630031 -0.37808285 0.00280150
H 17.31277155 7.39195042 10.93539770 0.00000000 0.38048834 0.23992616 0.50095745
H 11.84505596 8.79844934 3.92648162 0.00000000 0.09617687 -0.53107740 -0.60378195
H 11.52615792 11.44652110 7.92371163 0.00000000 0.04517207 0.38862373 0.62618192
H 20.78218151 13.07935111 0.92513510 0.00000000 -0.49508261 -0.56466228 -0.09680000
H 24.61711539 12.69539470 10.92214020 0.00000000 0.87000599 -0.05404968 -0.35535968
H 9.09594941 2.80333202 7.95033331 0.00000000 0.10104232 0.01204551 0.33397358
H 6.95092026 4.44476419 3.65551607 -0.00000000 0.18310559 -0.20465031 0.81008866
H 13.07555238 7.34697631 1.08645537 0.00000000 0.23070166 0.31902772 0.26045858
H 19.04198063 8.50951594 4.10283149 0.00000000 0.22792608 -0.27138971 0.01541628
H 18.29329127 16.16844071 11.24261935 -0.00000000 0.37876073 0.16521638 -0.02576211
H 4.78537516 3.56090119 1.12382828 0.00000000 0.36870756 1.08767380 0.60356398
H 18.60691102 11.72729791 7.98856647 0.00000000 -0.55819576 0.07444669 0.03123210
H 11.18445984 3.68366102 1.33577932 0.00000000 -0.49266331 -0.42868630 -0.73637819
H 12.94622738 9.34519669 8.00607741 0.00000000 -0.16835674 0.57625055 -0.28879830
H 10.52404189 10.96839470 3.88410417 0.00000000 0.11172472 -0.27072103 -0.82229649
H 26.41067535 16.36144183 11.08394202 0.00000000 0.20410786 0.29039336 0.13413951
H 22.58187133 13.10363845 7.32993452 0.00000000 0.02512952 0.07186285 -0.07401471
H 12.55413114 5.28549259 7.28487964 0.00000000 -0.02235129 -0.04048921 -0.23684377
H 15.42405498 3.58988145 2.53801188 0.00000000 -0.34947920 0.02384834 0.09021419
H 18.25643048 4.93275546 7.26670564 0.00000000 -0.11203298 0.29129704 -0.20001095
C 19.78778191 10.30495195 5.12809145 0.00000000 -0.54800673 0.65453955 -0.51609416
C 12.80980821 6.26074118 11.93756406 0.00000000 1.11834913 0.55927271 -1.47775984
C 3.28397790 0.57876490 0.58806051 0.00000000 0.05103387 0.19184345 1.52656116
C 17.74488342 9.90117006 7.04283584 0.00000000 0.56656521 -0.05882561 -0.90558873
C 16.40824325 7.46225275 0.49532087 -0.00000000 -0.74538895 -0.91398771 0.65936042
C 8.77120357 4.20140918 4.81130682 -0.00000000 0.83140408 -0.87598045 0.14246587
C 7.23680915 3.28095736 6.89316802 0.00000000 0.11078116 0.00319742 1.51923692
C 21.29830492 12.65237168 11.60248336 0.00000000 0.58187288 0.02507193 -0.16948253
C 18.06206746 6.81921496 11.50113773 0.00000000 0.40747767 -2.07918221 0.14262449
C 11.81127523 9.41107555 4.80673089 0.00000000 -0.10051697 -1.53060007 -1.15003097
C 11.62029585 10.85447232 7.03230801 0.00000000 1.28752895 0.39119803 1.73279861
C 12.71613914 0.70288188 0.52736572 0.00000000 -0.68916790 0.82900637 0.33976653
C 23.73599337 12.40140928 11.44193717 0.00000000 -1.89121688 0.54697591 -0.41856290
C 8.62215975 3.16729236 7.05294541 -0.00000000 -0.33611175 1.00796153 -1.56257503
C 7.40944796 4.13840125 4.62561089 0.00000000 -1.05768620 -0.66358101 -0.04099532
C 14.03605546 7.66493432 0.65574152 0.00000000 -1.96521670 -0.28704619 -0.10888283
C 18.95870959 9.18935834 4.94012070 0.00000000 1.15625723 1.64537886 0.66139756
C 11.42913395 4.27385075 11.72743310 0.00000000 0.50008025 1.00227303 -1.06547335
C 4.64442494 2.59100261 0.73640566 0.00000000 -0.90617630 -1.04531322 0.82790092
C 18.58994989 11.03056553 7.15400818 0.00000000 -0.95367309 -1.13339412 -0.27378905
C 11.27584065 2.79740097 0.66648308 0.00000000 2.19650576 -0.01781540 1.07154831
C 12.47877075 9.74580963 7.08892956 0.00000000 -1.31339283 -1.62844086 -1.02047273
C 11.08255426 10.56684887 4.70838950 0.00000000 -0.10076536 1.45964130 2.96811562
C 19.22716888 4.50423312 11.58317069 0.00000000 -0.55096104 2.18231802 -1.59354484
C 2.76317406 2.40290932 2.27771904 -0.00000000 -0.59775137 -1.58376622 1.27387516
C 17.40404588 7.57296994 5.94324446 -0.00000000 -1.56743116 1.23997613 1.16760305
C 20.22734071 12.51092336 6.15403650 -0.00000000 1.22957509 1.90035653 0.03562680
C 13.23276845 4.69549753 9.96908453 -0.00000000 0.50718825 2.08974222 0.61377288
C 5.10721376 3.61653945 5.70636886 0.00000000 2.73355068 0.65037236 -0.93035569
C 22.49474278 13.84651215 9.75433671 0.00000000 1.36969335 -0.12477390 0.80175651
C 15.21728223 6.34016578 2.42568485 -0.00000000 0.60210491 -0.63015004 -0.17150813
C 10.86509741 3.60814266 6.03571869 -0.00000000 -1.26808337 -0.09303192 -1.39703912
C 10.32059109 12.67417244 5.87289369 -0.00000000 -0.19562271 -1.61251767 0.60019594
C 13.01471446 2.21756005 2.44626754 -0.00000000 0.86092717 -0.42579082 -1.07737949
C 17.57499248 5.00231795 9.71452000 -0.00000000 1.10269277 0.95727595 1.80798168
C 13.09651755 7.51517580 6.02158091 -0.00000000 0.16152275 1.03866058 0.00051518
C 19.52961113 11.26506473 6.14689692 0.00000000 -0.47923424 -2.00926951 -0.00167911
C 12.52016611 5.09627002 11.25061211 0.00000000 -0.85821135 -2.14406407 0.00629697
C 3.56820853 1.83194621 1.20404202 0.00000000 -0.24947921 -0.42605623 -1.37841177
C 18.01743128 9.00143862 5.91933659 0.00000000 -1.89930848 -1.63034887 2.04649579
C 15.19452756 7.15622755 1.19432937 0.00000000 1.27993056 0.82959345 -1.10206065
C 9.38342828 3.65796068 5.96573626 0.00000000 -0.49527771 -0.34871106 0.87187589
C 6.62365389 3.59161306 5.69646438 0.00000000 0.08158964 1.40936005 -1.12924958
C 22.49822633 12.89119359 10.90410513 0.00000000 -0.01534993 0.25440759 0.98005323
C 18.38052004 5.54018958 10.94507371 0.00000000 -0.42962186 -2.37600741 0.12873569
C 12.44485007 8.87941212 5.94338497 0.00000000 0.55319558 1.55865145 1.55436660
C 10.96925659 11.34982105 5.93398978 0.00000000 0.87996439 -1.55717178 -2.73161243
C 12.32836771 1.89965457 1.17431195 0.00000000 -1.28221891 -0.14477468 -0.92883689
O 21.11711550 12.62179089 5.23136551 -0.00000000 0.33710932 0.64341980 0.00434688
O 14.05022859 5.61287803 9.51235314 0.00000000 -0.83755272 -1.62056392 -0.18869006
O 16.49586196 1.64212691 2.81325184 0.00000000 1.33843890 -0.08524281 -1.21436271
O 16.51483781 7.38700121 6.89142490 0.00000000 0.99256559 0.25144287 -0.72190444
O 16.35950251 6.04692579 2.90272957 -0.00000000 0.42654712 -0.18406863 0.28220103
O 11.37664037 2.85711218 6.91666764 -0.00000000 0.32622175 0.18088953 0.59787604
O 4.55849395 4.24912635 4.76601017 -0.00000000 -1.04603443 0.64592324 -0.98906917
O 21.39439287 14.20556624 9.26851925 -0.00000000 -1.02536927 -0.03426866 -0.07702566
O 17.79706063 3.81177830 9.36736657 -0.00000000 -0.07397140 -0.57323519 -0.87936489
O 12.83605234 6.72852854 5.06669023 0.00000000 0.17541133 -0.65136354 -0.15309324
O 10.57914227 13.43859821 6.85772410 0.00000000 -0.60809472 0.75544597 -0.13737370
O 12.86986208 3.39642398 2.81196643 -0.00000000 -1.01417472 2.24284815 1.56594584
O 23.65581139 14.27655500 9.41389153 0.00000000 0.00268412 0.04891899 -0.81907354
O 11.42472996 4.24906649 5.07945179 -0.00000000 1.02882190 -0.07815029 0.18245872
O 4.56702803 3.07814059 6.66707705 -0.00000000 -1.19995405 -1.71125358 2.55649186
O 14.11609555 5.89328931 2.87465801 -0.00000000 -0.82170982 0.25217877 0.18716108
O 17.73638295 6.70591416 5.10280868 0.00000000 0.49208963 -0.02062198 -0.46226448
O 12.87622968 3.59181029 9.44347702 -0.00000000 0.08008429 0.41248775 -0.15837691
O 3.11030822 3.53227393 2.72475726 0.00000000 -0.26757271 0.53939021 0.12620310
O 19.93588034 13.43377070 7.02852626 0.00000000 -0.07438593 -1.58871488 -0.77938934
O 13.84382844 1.39972231 2.97637902 -0.00000000 -0.37439899 -0.80008548 -0.03503645
O 13.86369820 7.28727219 7.04104684 -0.00000000 -0.51524344 0.22394783 -0.48768489
O 9.51842754 12.77225733 4.87638131 -0.00000000 -0.68242216 1.24384159 0.40002763
O 16.76119477 5.83819930 9.20759103 -0.00000000 -0.48896777 0.11154905 -0.30434738
O 15.26378930 3.61842800 3.49257841 -0.00000000 0.38698199 0.17880403 1.59589744
O 17.39344690 4.73233843 6.86898526 -0.00000000 -0.75289662 -0.06223975 -0.03985463
O 22.57211935 14.00424155 6.96757462 -0.00000000 0.04997077 0.26922446 -0.71939579
O 13.26041567 4.74033158 6.89718949 -0.00000000 0.16239629 0.28002000 0.44618120
O 15.25455144 3.61728206 7.77094086 -0.00000000 0.93027009 -0.35770884 0.29376541
O 13.83448042 2.67300898 5.54465998 -0.00000000 0.72276470 0.39229014 -1.20785068
O 15.34997262 5.24669185 5.48218276 -0.00000000 -0.82219018 0.10716146 -0.00022242
O 16.88908477 2.63088792 5.45947486 -0.00000000 -1.12156467 0.54434099 -1.53204722