Skip to content

Commit

Permalink
Merge pull request #174 from raimis/ds_ace_3
Browse files Browse the repository at this point in the history
Enable the Ace loader to return a conformation index
  • Loading branch information
RaulPPelaez committed May 25, 2023
2 parents 9b92392 + 1ca46c0 commit 722babd
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions torchmdnet/datasets/ace.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def sample_iter(self, mol_ids=False):
mols = list(h5.values())[0].items()
load_confs = self._load_confs_2_0
else:
raise RuntimeError(f"Unsuported layout verions: {version}")
raise RuntimeError(f"Unsupported layout version: {version}")

# Iterate over the molecules
for i_mol, (mol_id, mol) in tqdm(
Expand All @@ -180,7 +180,7 @@ def sample_iter(self, mol_ids=False):
fq = pt.tensor(mol["formal_charges"], dtype=pt.long)
q = fq.sum()

for pos, y, neg_dy, pq, dp in load_confs(mol, n_atoms=len(z)):
for i_conf, (pos, y, neg_dy, pq, dp) in enumerate(load_confs(mol, n_atoms=len(z))):

# Skip samples with large forces
if self.max_gradient:
Expand All @@ -193,6 +193,7 @@ def sample_iter(self, mol_ids=False):
)
if mol_ids:
args["mol_id"] = mol_id
args["i_conf"] = i_conf
data = Data(**args)

if self.pre_filter is not None and not self.pre_filter(data):
Expand Down

0 comments on commit 722babd

Please sign in to comment.