Commit d1415cf
Merge pull request #132 from usnistgov/develop
Develop
knc6 committed Aug 11, 2023
2 parents 4f14235 + 02c0fca commit d1415cf
Showing 11 changed files with 132 additions and 37 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -101,7 +101,8 @@ python -m pip install -e .

As an alternate method, ALIGNN can also be installed using the `pip` command as follows:
```
-python -m pip install alignn
+pip install alignn
+pip install dgl==1.0.1+cu117 -f https://data.dgl.ai/wheels/cu117/repo.html
```

<a name="example"></a>
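A quick sanity check after installing (not part of the README diff; the expected version strings reflect this release and the pinned DGL wheel):

```python
# Minimal post-install check; assumes the pip commands above succeeded.
import alignn
import dgl
import torch

print("alignn:", alignn.__version__)         # expected: 2023.08.01
print("dgl:", dgl.__version__)               # expected: 1.0.1+cu117
print("CUDA available:", torch.cuda.is_available())
```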
2 changes: 1 addition & 1 deletion alignn/__init__.py
@@ -1,2 +1,2 @@
"""Version number."""
__version__ = "2023.07.10"
__version__ = "2023.08.01"
(file header not captured in this view; a JSON training-config file)
@@ -36,7 +36,7 @@
"cutoff": 8.0,
"max_neighbors": 12,
"keep_data_order": true,
"distributed":true,
"distributed":false,
"model": {
"name": "alignn_atomwise",
"atom_input_features": 92,
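For context, a config like this is parsed into the pydantic `TrainingConfig` that the training scripts consume; a minimal loading sketch, assuming a local `config.json` containing the fields above (the path is a placeholder):

```python
from jarvis.db.jsonutils import loadjson
from alignn.config import TrainingConfig

# Parse and validate the example config; "config.json" is a placeholder path.
d = loadjson("config.json")
config = TrainingConfig(**d)
print(config.model.name)    # "alignn_atomwise"
print(config.distributed)   # False after this change
```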
8 changes: 8 additions & 0 deletions alignn/ff/ff.py
@@ -64,6 +64,12 @@

__author__ = "Kamal Choudhary, Brian DeCost, Keith Butler, Lily Major"

+scf_fd_top_10_en_42_fmax_600_wt01 = (
+    "https://figshare.com/ndownloader/files/41967375"
+)
+scf_fd_top_10_en_42_fmax_600_wt10 = (
+    "https://figshare.com/ndownloader/files/41967372"
+)
all_models_ff = {
"alignnff_fmult": "https://figshare.com/ndownloader/files/41583585",
"alignnff_wt10": "https://figshare.com/ndownloader/files/41583594",
@@ -72,6 +78,8 @@
"alignnff_wt1": "https://figshare.com/ndownloader/files/41583591",
"fmult_mlearn_only": "https://figshare.com/ndownloader/files/41583597",
"revised": "https://figshare.com/ndownloader/files/41583600",
"scf_fd_top_10_en_42_fmax_600_wt01": scf_fd_top_10_en_42_fmax_600_wt01,
"scf_fd_top_10_en_42_fmax_600_wt10": scf_fd_top_10_en_42_fmax_600_wt10,
}


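The two new registry entries can be pulled like any other force-field checkpoint; a hedged sketch, assuming `get_figshare_model_ff` and the ASE-style `AlignnAtomwiseCalculator` in `alignn.ff.ff` keep their usual signatures (check the installed version):

```python
from alignn.ff.ff import AlignnAtomwiseCalculator, get_figshare_model_ff

# Download one of the newly registered checkpoints by registry key
# and wrap it in the ASE-compatible calculator.
model_path = get_figshare_model_ff(
    model_name="scf_fd_top_10_en_42_fmax_600_wt10"
)
calc = AlignnAtomwiseCalculator(path=model_path)
```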
9 changes: 9 additions & 0 deletions alignn/models/alignn_atomwise.py
@@ -15,6 +15,7 @@
from torch import nn
from torch.nn import functional as F
from alignn.models.utils import RBFExpansion
+from alignn.graphs import compute_bond_cosines
from alignn.utils import BaseSettings


@@ -51,6 +52,7 @@ class ALIGNNAtomWiseConfig(BaseSettings):
    inner_cutoff: float = 6  # Angstrom
stress_multiplier: float = 1
add_reverse_forces: bool = False # will make True as default soon
+    lg_on_fly: bool = False  # will make True as default soon
batch_stress: bool = True

class Config:
@@ -328,6 +330,13 @@ def forward(
r = g.edata["r"]
if self.config.calculate_gradient:
r.requires_grad_(True)
+        if self.config.lg_on_fly and len(self.alignn_layers) > 0:
+            # re-compute bond angle cosines here to ensure
+            # the three-body interactions are fully included
+            # in the autograd graph. don't rely on dataloader/caching.
+            lg.ndata["r"] = r  # overwrites precomputed r values
+            lg.apply_edges(compute_bond_cosines)  # overwrites precomputed h
+            z = self.angle_embedding(lg.edata.pop("h"))

# r = g.edata["r"].clone().detach().requires_grad_(True)
bondlength = torch.norm(r, dim=1)
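The new `lg_on_fly` path rebuilds the angle (line-graph) features from the current, gradient-tracked bond vectors `r`, so force and stress terms pick up the full three-body dependence instead of cached dataloader values. For reference, a sketch of a `compute_bond_cosines`-style DGL edge UDF; the real function lives in `alignn.graphs`, and this is an approximation rather than the verbatim source:

```python
import torch

def compute_bond_cosines(edges):
    """DGL edge UDF: cosine of the angle between two bonds sharing an atom.

    In the line graph each node stores a bond vector "r"; an edge joins two
    bonds that meet at a common atom, so src/dst carry the two vectors.
    """
    r1 = -edges.src["r"]  # flip so both vectors point away from the shared atom
    r2 = edges.dst["r"]
    cos = torch.sum(r1 * r2, dim=1) / (
        torch.norm(r1, dim=1) * torch.norm(r2, dim=1)
    )
    return {"h": torch.clamp(cos, -1.0, 1.0)}
```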
10 changes: 10 additions & 0 deletions alignn/pretrained.py
@@ -124,6 +124,14 @@
"https://figshare.com/ndownloader/files/31458814",
1,
],
"tinnet_O_alignn": ["https://figshare.com/ndownloader/files/41962800", 1],
"tinnet_N_alignn": ["https://figshare.com/ndownloader/files/41962797", 1],
"tinnet_OH_alignn": ["https://figshare.com/ndownloader/files/41962803", 1],
"AGRA_O_alignn": ["https://figshare.com/ndownloader/files/41966619", 1],
"AGRA_OH_alignn": ["https://figshare.com/ndownloader/files/41966610", 1],
"AGRA_CHO_alignn": ["https://figshare.com/ndownloader/files/41966643", 1],
"AGRA_CO_alignn": ["https://figshare.com/ndownloader/files/41966634", 1],
"AGRA_COOH_alignn": ["https://figshare.com/ndownloader/41966646", 1],
"qm9_U0_alignn": ["https://figshare.com/ndownloader/files/31459054", 1],
"qm9_U_alignn": ["https://figshare.com/ndownloader/files/31459051", 1],
"qm9_alpha_alignn": ["https://figshare.com/ndownloader/files/31459027", 1],
@@ -155,6 +163,8 @@
1,
],
"ocp2020_all": ["https://figshare.com/ndownloader/files/41411025", 1],
"ocp2020_100k": ["https://figshare.com/ndownloader/files/41967303", 1],
"ocp2020_10k": ["https://figshare.com/ndownloader/files/41967330", 1],
"jv_pdos_alignn": [
"https://figshare.com/ndownloader/files/36757005",
66,
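Each new key becomes a valid `model_name` for the pretrained-model helpers; a hedged usage sketch, assuming `get_prediction` in `alignn.pretrained` keeps its usual signature and a POSCAR file is available (both are assumptions here):

```python
from jarvis.core.atoms import Atoms
from alignn.pretrained import get_prediction

# "POSCAR" is a placeholder structure file.
atoms = Atoms.from_poscar("POSCAR")
out = get_prediction(model_name="ocp2020_100k", atoms=atoms)
print(out)
```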
2 changes: 1 addition & 1 deletion alignn/run_alignn_ff.py
@@ -65,7 +65,7 @@
parser.add_argument(
"--device",
default=None,
help="set device for executing the model [e.g. cpu, cuda, cuda:2]"
help="set device for executing the model [e.g. cpu, cuda, cuda:2]",
)

if __name__ == "__main__":
64 changes: 47 additions & 17 deletions alignn/train.py
@@ -59,20 +59,17 @@
import json
import pprint


# from accelerate import Accelerator
import os
import warnings


warnings.filterwarnings("ignore", category=RuntimeWarning)
# from sklearn.decomposition import PCA, KernelPCA
# from sklearn.preprocessing import StandardScaler

# torch config
torch.set_default_dtype(torch.float32)


device = "cpu"
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -90,9 +87,8 @@ def make_standard_scalar_and_pca(output):
"""Use standard scalar and PCS for multi-output data."""
sc = pk.load(open(os.path.join(tmp_output_dir, "sc.pkl"), "rb"))
y_pred, y = output
-    y_pred = torch.tensor(sc.transform(y_pred.cpu().numpy()),
-                          device=y_pred.device)
-    y = torch.tensor(sc.transform(y.cpu().numpy()), device=y.device)
+    y_pred = torch.tensor(sc.transform(y_pred.cpu().numpy()), device=device)
+    y = torch.tensor(sc.transform(y.cpu().numpy()), device=device)
# pc = pk.load(open("pca.pkl", "rb"))
# y_pred = torch.tensor(pc.transform(y_pred), device=device)
# y = torch.tensor(pc.transform(y), device=device)
@@ -791,6 +787,11 @@ def get_batch_errors(dat=[]):
filename=os.path.join(config.output_dir, "Test_results.json"),
data=test_result,
)
+        last_model_name = "last_model.pt"
+        torch.save(
+            net.state_dict(),
+            os.path.join(config.output_dir, last_model_name),
+        )
return test_result

if config.distributed:
@@ -951,13 +952,44 @@ def zig_prediction_transform(x):
"lr_scheduler": scheduler,
"trainer": trainer,
}
-        handler = Checkpoint(
-            to_save,
-            DiskSaver(checkpoint_dir, create_dir=True, require_empty=False),
-            n_saved=2,
-            global_step_transform=lambda *_: trainer.state.epoch,
-        )
+        if classification:
+
+            def cp_score(engine):
+                """Higher accuracy is better."""
+                return engine.state.metrics["accuracy"]
+
+        else:
+
+            def cp_score(engine):
+                """Lower MAE is better."""
+                return -engine.state.metrics["mae"]
+
+        # save last two epochs
+        evaluator.add_event_handler(
+            Events.EPOCH_COMPLETED,
+            Checkpoint(
+                to_save,
+                DiskSaver(
+                    checkpoint_dir, create_dir=True, require_empty=False
+                ),
+                n_saved=2,
+                global_step_transform=lambda *_: trainer.state.epoch,
+            ),
+        )
+        # save best model
+        evaluator.add_event_handler(
+            Events.EPOCH_COMPLETED,
+            Checkpoint(
+                to_save,
+                DiskSaver(
+                    checkpoint_dir, create_dir=True, require_empty=False
+                ),
+                filename_pattern="best_model.{ext}",
+                n_saved=1,
+                global_step_transform=lambda *_: trainer.state.epoch,
+                score_function=cp_score,
+            ),
+        )
-        trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
if config.progress:
pbar = ProgressBar()
pbar.attach(trainer, output_transform=lambda x: {"loss": x})
@@ -1028,13 +1060,13 @@ def log_results(engine):

def es_score(engine):
"""Higher accuracy is better."""
-            engine.state.metrics["accuracy"]
+            return engine.state.metrics["accuracy"]

else:

def es_score(engine):
"""Lower MAE is better."""
-            -engine.state.metrics["mae"]
+            return -engine.state.metrics["mae"]

es_handler = EarlyStopping(
patience=config.n_early_stopping,
@@ -1067,7 +1099,7 @@ def es_score(engine):
test_loss = evaluator.state.metrics["loss"]
tb_logger.writer.add_hparams(config, {"hparam/test_loss": test_loss})
tb_logger.close()
-    if config.write_predictions and classification and test_loader is not None:
+    if config.write_predictions and classification:
net.eval()
f = open(
os.path.join(config.output_dir, "prediction_results_test_set.csv"),
@@ -1104,7 +1136,6 @@ def es_score(engine):
config.write_predictions
and not classification
and config.model.output_features > 1
-        and test_loader is not None
):
net.eval()
mem = []
@@ -1135,7 +1166,6 @@ def es_score(engine):
config.write_predictions
and not classification
and config.model.output_features == 1
-        and test_loader is not None
):
net.eval()
f = open(
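The reworked checkpointing keeps the two most recent epoch checkpoints plus one best model, relying on Ignite's convention that `score_function` returns higher-is-better (hence `-mae` for regression). A standalone sketch of the pattern with stand-in objects, independent of this training loop:

```python
import torch
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver

net = torch.nn.Linear(4, 1)                     # stand-in model
evaluator = Engine(lambda engine, batch: None)  # stand-in evaluation loop

def cp_score(engine):
    # Ignite retains the checkpoint with the highest score, so negate MAE
    # to make lower error win. In real use the "mae" metric would be
    # attached to the evaluator beforehand.
    return -engine.state.metrics["mae"]

best_handler = Checkpoint(
    {"model": net},
    DiskSaver("checkpoints", create_dir=True, require_empty=False),
    filename_pattern="best_model.{ext}",
    n_saved=1,
    score_function=cp_score,
)
evaluator.add_event_handler(Events.EPOCH_COMPLETED, best_handler)
```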
33 changes: 32 additions & 1 deletion alignn/train_folder.py
@@ -11,6 +11,13 @@
from alignn.config import TrainingConfig
from jarvis.db.jsonutils import loadjson
import argparse
+import glob
+import torch
+
+device = "cpu"
+if torch.cuda.is_available():
+    device = torch.device("cuda")


parser = argparse.ArgumentParser(
description="Atomistic Line Graph Neural Network"
@@ -60,7 +67,12 @@
parser.add_argument(
"--device",
default=None,
help="set device for training the model [e.g. cpu, cuda, cuda:2]"
help="set device for training the model [e.g. cpu, cuda, cuda:2]",
)
+parser.add_argument(
+    "--restart_model_path",
+    default=None,
+    help="Checkpoint file path for model",
+)


@@ -71,6 +83,7 @@ def train_for_folder(
classification_threshold=None,
batch_size=None,
epochs=None,
+    restart_model_path=None,
file_format="poscar",
output_dir=None,
):
@@ -93,6 +106,22 @@
config.batch_size = int(batch_size)
if epochs is not None:
config.epochs = int(epochs)
+    if restart_model_path is not None:
+        print("Restarting model from:", restart_model_path)
+        from alignn.models.alignn import ALIGNN, ALIGNNConfig
+
+        rest_config = loadjson(os.path.join(restart_model_path, "config.json"))
+        print("rest_config", rest_config)
+        model = ALIGNN(ALIGNNConfig(**rest_config["model"]))
+        chk_glob = os.path.join(restart_model_path, "*.pt")
+        tmp = "na"
+        for i in glob.glob(chk_glob):
+            tmp = i
+        print("Checkpoint file", tmp)
+        model.load_state_dict(torch.load(tmp, map_location=device)["model"])
+        model.to(device)
+    else:
+        model = None
with open(id_prop_dat, "r") as f:
reader = csv.reader(f)
data = [row for row in reader]
@@ -185,6 +214,7 @@ def train_for_folder(
t1 = time.time()
train_dgl(
config,
+        model,
train_val_test_loaders=[
train_loader,
val_loader,
@@ -209,4 +239,5 @@ def train_for_folder(
batch_size=(args.batch_size),
epochs=(args.epochs),
file_format=(args.file_format),
+        restart_model_path=(args.restart_model_path),
)
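With the new argument, training can be warm-started from a previous run's output directory (one holding `config.json` and a `*.pt` checkpoint). A hedged sketch via the Python API; `root_dir` and `config_name` are parameters of `train_for_folder` defined outside this hunk, and the paths are placeholders:

```python
from alignn.train_folder import train_for_folder

train_for_folder(
    root_dir="data_folder",             # placeholder dataset directory
    config_name="config.json",          # placeholder config path
    restart_model_path="previous_run",  # directory with config.json + *.pt
    epochs=50,
)
```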
34 changes: 20 additions & 14 deletions alignn/train_folder_ff.py
@@ -110,7 +110,7 @@
parser.add_argument(
"--device",
default=None,
help="set device for training the model [e.g. cpu, cuda, cuda:2]"
help="set device for training the model [e.g. cpu, cuda, cuda:2]",
)


@@ -218,22 +218,28 @@ def train_for_folder(

model = None
if restart_model_path is not None:
+        # Should be best_model.pt file
print("Restarting the model training:", restart_model_path)
if config.model.name == "alignn_atomwise":
-            tmp = ALIGNNAtomWiseConfig(
-                name="alignn_atomwise",
-                output_features=config.model.output_features,
-                alignn_layers=config.model.alignn_layers,
-                atomwise_weight=config.model.atomwise_weight,
-                stresswise_weight=config.model.stresswise_weight,
-                graphwise_weight=config.model.graphwise_weight,
-                gradwise_weight=config.model.gradwise_weight,
-                gcn_layers=config.model.gcn_layers,
-                atom_input_features=config.model.atom_input_features,
-                edge_input_features=config.model.edge_input_features,
-                triplet_input_features=config.model.triplet_input_features,
-                embedding_features=config.model.embedding_features,
+            rest_config = loadjson(
+                restart_model_path.replace("best_model.pt", "config.json")
+            )
+
+            tmp = ALIGNNAtomWiseConfig(**rest_config["model"])
+            # tmp = ALIGNNAtomWiseConfig(
+            #     name="alignn_atomwise",
+            #     output_features=config.model.output_features,
+            #     alignn_layers=config.model.alignn_layers,
+            #     atomwise_weight=config.model.atomwise_weight,
+            #     stresswise_weight=config.model.stresswise_weight,
+            #     graphwise_weight=config.model.graphwise_weight,
+            #     gradwise_weight=config.model.gradwise_weight,
+            #     gcn_layers=config.model.gcn_layers,
+            #     atom_input_features=config.model.atom_input_features,
+            #     edge_input_features=config.model.edge_input_features,
+            #     triplet_input_features=config.model.triplet_input_features,
+            #     embedding_features=config.model.embedding_features,
+            # )
print("Rest config", tmp)
# for i,j in config_dict['model'].items():
# print ('i',i)
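Note that `restart_model_path.replace("best_model.pt", "config.json")` assumes the checkpoint is named exactly `best_model.pt`. A slightly more defensive variant (a sketch, not what this diff does) resolves the directory first:

```python
import os
from jarvis.db.jsonutils import loadjson

restart_model_path = "out/best_model.pt"  # placeholder checkpoint path

# Look for config.json next to whichever checkpoint file was given.
cfg_path = os.path.join(os.path.dirname(restart_model_path), "config.json")
rest_config = loadjson(cfg_path)
```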
2 changes: 1 addition & 1 deletion setup.py
@@ -10,7 +10,7 @@

setuptools.setup(
name="alignn",
version="2023.07.10",
version="2023.08.01",
author="Kamal Choudhary, Brian DeCost",
author_email="kamal.choudhary@nist.gov",
description="alignn",
