diff --git a/torchmdnet/datasets/ace.py b/torchmdnet/datasets/ace.py index 9ab3ecbaf..e2470dcb2 100644 --- a/torchmdnet/datasets/ace.py +++ b/torchmdnet/datasets/ace.py @@ -162,7 +162,7 @@ def sample_iter(self, mol_ids=False): mols = list(h5.values())[0].items() load_confs = self._load_confs_2_0 else: - raise RuntimeError(f"Unsuported layout verions: {version}") + raise RuntimeError(f"Unsupported layout version: {version}") # Iterate over the molecules for i_mol, (mol_id, mol) in tqdm( @@ -180,7 +180,7 @@ def sample_iter(self, mol_ids=False): fq = pt.tensor(mol["formal_charges"], dtype=pt.long) q = fq.sum() - for pos, y, neg_dy, pq, dp in load_confs(mol, n_atoms=len(z)): + for i_conf, (pos, y, neg_dy, pq, dp) in enumerate(load_confs(mol, n_atoms=len(z))): # Skip samples with large forces if self.max_gradient: @@ -193,6 +193,7 @@ def sample_iter(self, mol_ids=False): ) if mol_ids: args["mol_id"] = mol_id + args["i_conf"] = i_conf data = Data(**args) if self.pre_filter is not None and not self.pre_filter(data):