Skip to content

Commit

Permalink
Merge pull request #185 from AntonioMirarchi/update_spice_dataloader
Browse files Browse the repository at this point in the history
fix torch tensor warning in spice dataloader
  • Loading branch information
RaulPPelaez committed Jun 9, 2023
2 parents 237b4fe + c7290d0 commit d3f6f22
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torchmdnet/datasets/spice.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,17 @@ def sample_iter(self, mol_ids=False):
if i_mol % self.subsample_molecules != 0:
continue

z = pt.tensor(mol["atomic_numbers"], dtype=pt.long)
z = pt.tensor(np.array(mol["atomic_numbers"]), dtype=pt.long)
all_pos = (
pt.tensor(mol["conformations"], dtype=pt.float32)
pt.tensor(np.array(mol["conformations"]), dtype=pt.float32)
* self.BORH_TO_ANGSTROM
)
all_y = (
pt.tensor(mol["formation_energy"], dtype=pt.float64)
pt.tensor(np.array(mol["formation_energy"]), dtype=pt.float64)
* self.HARTREE_TO_EV
)
all_neg_dy = (
-pt.tensor(mol["dft_total_gradient"], dtype=pt.float32)
-pt.tensor(np.array(mol["dft_total_gradient"]), dtype=pt.float32)
* self.HARTREE_TO_EV
/ self.BORH_TO_ANGSTROM
)
Expand Down

0 comments on commit d3f6f22

Please sign in to comment.