Skip to content

Commit

Permalink
Merge pull request #121 from torchmd/naming
Browse files Browse the repository at this point in the history
Consistent naming between dy and forces
  • Loading branch information
PhilippThoelke committed Oct 11, 2022
2 parents da67c8f + 4852f63 commit d091c8f
Show file tree
Hide file tree
Showing 18 changed files with 158 additions and 159 deletions.
6 changes: 3 additions & 3 deletions examples/ET-ANI1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ dataset_root: ~/data
derivative: false
distance_influence: both
early_stopping_patience: 500
ema_alpha_dy: 1.0
ema_alpha_neg_dy: 1.0
ema_alpha_y: 1.0
embed_files: null
embedding_dimension: 128
energy_files: null
energy_weight: 1.0
y_weight: 1.0
force_files: null
force_weight: 1.0
neg_dy_weight: 1.0
inference_batch_size: 2048
load_model: null
log_dir: logs/
Expand Down
6 changes: 3 additions & 3 deletions examples/ET-MD17.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ dataset_root: ~/data
derivative: true
distance_influence: both
early_stopping_patience: 300
ema_alpha_dy: 1.0
ema_alpha_neg_dy: 1.0
ema_alpha_y: 0.05
embed_files: null
embedding_dimension: 128
energy_files: null
energy_weight: 0.2
y_weight: 0.2
force_files: null
force_weight: 0.8
neg_dy_weight: 0.8
inference_batch_size: 64
load_model: null
log_dir: logs/
Expand Down
6 changes: 3 additions & 3 deletions examples/ET-QM9.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ dataset_root: ~/data
derivative: false
distance_influence: both
early_stopping_patience: 150
ema_alpha_dy: 1.0
ema_alpha_neg_dy: 1.0
ema_alpha_y: 1.0
embed_files: null
embedding_dimension: 256
energy_files: null
energy_weight: 1.0
y_weight: 1.0
force_files: null
force_weight: 1.0
neg_dy_weight: 1.0
inference_batch_size: 128
load_model: null
log_dir: logs/
Expand Down
6 changes: 3 additions & 3 deletions examples/ET-SPICE.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ dataset_root: data
derivative: true
distance_influence: both
early_stopping_patience: 50
ema_alpha_dy: 1.0
ema_alpha_neg_dy: 1.0
ema_alpha_y: 1.0
embed_files: null
embedding_dimension: 128
energy_files: null
energy_weight: 0.5
y_weight: 0.5
force_files: null
force_weight: 0.5
neg_dy_weight: 0.5
inference_batch_size: 16
load_model: null
log_dir: logs/
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dataset_comp6.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_dataset_s66x8():
)
assert pt.allclose(sample.y, pt.tensor([[-47.5919]]))
assert pt.allclose(
sample.dy,
sample.neg_dy,
pt.tensor(
[
[0.2739, -0.2190, -0.0012],
Expand Down
2 changes: 1 addition & 1 deletion tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_custom(energy, forces, num_files, tmpdir, num_samples=100):
if energy:
assert hasattr(sample, "y"), "Sample doesn't contain energy"
if forces:
assert hasattr(sample, "dy"), "Sample doesn't contain forces"
assert hasattr(sample, "neg_dy"), "Sample doesn't contain forces"


def test_hdf5_multiprocessing(tmpdir, num_entries=100):
Expand Down
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def get(self, idx):
if self.energies is not None:
features["y"] = self.energies[idx].clone()
if self.forces is not None:
features["dy"] = self.forces[idx].clone()
features["neg_dy"] = self.forces[idx].clone()
return Data(**features)

def len(self):
Expand Down
2 changes: 1 addition & 1 deletion torchmdnet/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _get_dataloader(self, dataset, stage, store_dataloader=True):

def _standardize(self):
def get_energy(batch, atomref):
if batch.y is None:
if "y" not in batch or batch.y is None:
raise MissingEnergyException()

if atomref is None:
Expand Down
39 changes: 21 additions & 18 deletions torchmdnet/datasets/ace.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@


class Ace(Dataset):

def __init__(
self,
root=None,
Expand All @@ -33,7 +32,7 @@ def __init__(
z_name,
pos_name,
y_name,
dy_name,
neg_dy_name,
q_name,
pq_name,
dp_name,
Expand All @@ -44,8 +43,8 @@ def __init__(
pos_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3)
)
self.y_mm = np.memmap(y_name, mode="r", dtype=np.float64)
self.dy_mm = np.memmap(
dy_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3)
self.neg_dy_mm = np.memmap(
neg_dy_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3)
)
self.q_mm = np.memmap(q_name, mode="r", dtype=np.int8)
self.pq_mm = np.memmap(pq_name, mode="r", dtype=np.float32)
Expand Down Expand Up @@ -78,7 +77,9 @@ def sample_iter(self, mol_ids=False):
for path in tqdm(self.raw_paths, desc="Files"):
molecules = list(h5py.File(path).items())

for i_mol, (mol_id, mol) in tqdm(enumerate(molecules), desc="Molecules", leave=False):
for i_mol, (mol_id, mol) in tqdm(
enumerate(molecules), desc="Molecules", leave=False
):

# Subsample molecules
if i_mol % self.subsample_molecules != 0:
Expand All @@ -103,8 +104,8 @@ def sample_iter(self, mol_ids=False):
assert y.shape == ()

assert conf["forces"].attrs["units"] == "eV/Å"
dy = -pt.tensor(conf["forces"], dtype=pt.float32)
assert dy.shape == pos.shape
neg_dy = -pt.tensor(conf["forces"], dtype=pt.float32)
assert neg_dy.shape == pos.shape

assert conf["partial_charges"].attrs["units"] == "e"
pq = pt.tensor(conf["partial_charges"], dtype=pt.float32)
Expand All @@ -116,11 +117,13 @@ def sample_iter(self, mol_ids=False):

# Skip samples with large forces
if self.max_gradient:
if dy.norm(dim=1).max() > float(self.max_gradient):
if neg_dy.norm(dim=1).max() > float(self.max_gradient):
continue

# Create a sample
args = dict(z=z, pos=pos, y=y.view(1, 1), dy=dy, q=q, pq=pq, dp=dp)
args = dict(
z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, q=q, pq=pq, dp=dp
)
if mol_ids:
args["mol_id"] = mol_id
data = Data(**args)
Expand All @@ -140,7 +143,7 @@ def processed_file_names(self):
f"{self.name}.z.mmap",
f"{self.name}.pos.mmap",
f"{self.name}.y.mmap",
f"{self.name}.dy.mmap",
f"{self.name}.neg_dy.mmap",
f"{self.name}.q.mmap",
f"{self.name}.pq.mmap",
f"{self.name}.dp.mmap",
Expand All @@ -167,7 +170,7 @@ def process(self):
z_name,
pos_name,
y_name,
dy_name,
neg_dy_name,
q_name,
pq_name,
dp_name,
Expand All @@ -182,8 +185,8 @@ def process(self):
y_mm = np.memmap(
y_name + ".tmp", mode="w+", dtype=np.float64, shape=num_all_confs
)
dy_mm = np.memmap(
dy_name + ".tmp", mode="w+", dtype=np.float32, shape=(num_all_atoms, 3)
neg_dy_mm = np.memmap(
neg_dy_name + ".tmp", mode="w+", dtype=np.float32, shape=(num_all_atoms, 3)
)
q_mm = np.memmap(q_name + ".tmp", mode="w+", dtype=np.int8, shape=num_all_confs)
pq_mm = np.memmap(
Expand All @@ -202,7 +205,7 @@ def process(self):
z_mm[i_atom:i_next_atom] = data.z.to(pt.int8)
pos_mm[i_atom:i_next_atom] = data.pos
y_mm[i_conf] = data.y
dy_mm[i_atom:i_next_atom] = data.dy
neg_dy_mm[i_atom:i_next_atom] = data.neg_dy
q_mm[i_conf] = data.q.to(pt.int8)
pq_mm[i_atom:i_next_atom] = data.pq
dp_mm[i_conf] = data.dp
Expand All @@ -216,7 +219,7 @@ def process(self):
z_mm.flush()
pos_mm.flush()
y_mm.flush()
dy_mm.flush()
neg_dy_mm.flush()
q_mm.flush()
pq_mm.flush()
dp_mm.flush()
Expand All @@ -225,7 +228,7 @@ def process(self):
os.rename(z_mm.filename, z_name)
os.rename(pos_mm.filename, pos_name)
os.rename(y_mm.filename, y_name)
os.rename(dy_mm.filename, dy_name)
os.rename(neg_dy_mm.filename, neg_dy_name)
os.rename(q_mm.filename, q_name)
os.rename(pq_mm.filename, pq_name)
os.rename(dp_mm.filename, dp_name)
Expand All @@ -241,9 +244,9 @@ def get(self, idx):
y = pt.tensor(self.y_mm[idx], dtype=pt.float32).view(
1, 1
) # It would be better to use float64, but the trainer complains
dy = pt.tensor(self.dy_mm[atoms], dtype=pt.float32)
neg_dy = pt.tensor(self.neg_dy_mm[atoms], dtype=pt.float32)
q = pt.tensor(self.q_mm[idx], dtype=pt.long)
pq = pt.tensor(self.pq_mm[atoms], dtype=pt.float32)
dp = pt.tensor(self.dp_mm[idx], dtype=pt.float32)

return Data(z=z, pos=pos, y=y, dy=dy, q=q, pq=pq, dp=dp)
return Data(z=z, pos=pos, y=y, neg_dy=neg_dy, q=q, pq=pq, dp=dp)
57 changes: 28 additions & 29 deletions torchmdnet/datasets/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,18 @@ def __init__(
self.name = self.__class__.__name__
super().__init__(root, transform, pre_transform, pre_filter)

idx_name, z_name, pos_name, y_name, dy_name = self.processed_paths
idx_name, z_name, pos_name, y_name, neg_dy_name = self.processed_paths
self.idx_mm = np.memmap(idx_name, mode="r", dtype=np.int64)
self.z_mm = np.memmap(z_name, mode="r", dtype=np.int8)
self.pos_mm = np.memmap(
pos_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3)
)
self.y_mm = np.memmap(y_name, mode="r", dtype=np.float64)
self.dy_mm = (
self.neg_dy_mm = (
np.memmap(
dy_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3)
neg_dy_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3)
)
if os.path.getsize(dy_name) > 0
if os.path.getsize(neg_dy_name) > 0
else None
)

Expand All @@ -66,7 +66,7 @@ def processed_file_names(self):
f"{self.name}.z.mmap",
f"{self.name}.pos.mmap",
f"{self.name}.y.mmap",
f"{self.name}.dy.mmap",
f"{self.name}.neg_dy.mmap",
]

def filter_and_pre_transform(self, data):
Expand All @@ -80,20 +80,19 @@ def filter_and_pre_transform(self, data):
return data

def process(self):

print("Gathering statistics...")
num_all_confs = 0
num_all_atoms = 0
for data in self.sample_iter():
num_all_confs += 1
num_all_atoms += data.z.shape[0]
has_dy = "dy" in data
has_neg_dy = "neg_dy" in data

print(f" Total number of conformers: {num_all_confs}")
print(f" Total number of atoms: {num_all_atoms}")
print(f" Forces available: {has_dy}")
print(f" Forces available: {has_neg_dy}")

idx_name, z_name, pos_name, y_name, dy_name = self.processed_paths
idx_name, z_name, pos_name, y_name, neg_dy_name = self.processed_paths
idx_mm = np.memmap(
idx_name + ".tmp", mode="w+", dtype=np.int64, shape=(num_all_confs + 1,)
)
Expand All @@ -106,12 +105,12 @@ def process(self):
y_mm = np.memmap(
y_name + ".tmp", mode="w+", dtype=np.float64, shape=(num_all_confs,)
)
dy_mm = (
neg_dy_mm = (
np.memmap(
dy_name + ".tmp", mode="w+", dtype=np.float32, shape=(num_all_atoms, 3)
neg_dy_name + ".tmp", mode="w+", dtype=np.float32, shape=(num_all_atoms, 3)
)
if has_dy
else open(dy_name, "w")
if has_neg_dy
else open(neg_dy_name, "w")
)

print("Storing data...")
Expand All @@ -123,8 +122,8 @@ def process(self):
z_mm[i_atom:i_next_atom] = data.z.to(pt.int8)
pos_mm[i_atom:i_next_atom] = data.pos
y_mm[i_conf] = data.y
if has_dy:
dy_mm[i_atom:i_next_atom] = data.dy
if has_neg_dy:
neg_dy_mm[i_atom:i_next_atom] = data.neg_dy

i_atom = i_next_atom

Expand All @@ -135,15 +134,15 @@ def process(self):
z_mm.flush()
pos_mm.flush()
y_mm.flush()
if has_dy:
dy_mm.flush()
if has_neg_dy:
neg_dy_mm.flush()

os.rename(idx_mm.filename, idx_name)
os.rename(z_mm.filename, z_name)
os.rename(pos_mm.filename, pos_name)
os.rename(y_mm.filename, y_name)
if has_dy:
os.rename(dy_mm.filename, dy_name)
if has_neg_dy:
os.rename(neg_dy_mm.filename, neg_dy_name)

def len(self):
return len(self.y_mm)
Expand All @@ -158,11 +157,11 @@ def get(self, idx):
) # It would be better to use float64, but the trainer complains
y -= self.compute_reference_energy(z)

if self.dy_mm is None:
if self.neg_dy_mm is None:
return Data(z=z, pos=pos, y=y)
else:
dy = pt.tensor(self.dy_mm[atoms], dtype=pt.float32)
return Data(z=z, pos=pos, y=y, dy=dy)
neg_dy = pt.tensor(self.neg_dy_mm[atoms], dtype=pt.float32)
return Data(z=z, pos=pos, y=y, neg_dy=neg_dy)


class ANI1(ANIBase):
Expand Down Expand Up @@ -282,25 +281,25 @@ def sample_iter(self, mol_ids=False):
all_y = pt.tensor(
mol["wb97x_dz.energy"][:] * self.HARTREE_TO_EV, dtype=pt.float64
)
all_dy = pt.tensor(
all_neg_dy = pt.tensor(
mol["wb97x_dz.forces"][:] * self.HARTREE_TO_EV, dtype=pt.float32
)

assert all_pos.shape[0] == all_y.shape[0]
assert all_pos.shape[1] == z.shape[0]
assert all_pos.shape[2] == 3

assert all_dy.shape[0] == all_y.shape[0]
assert all_dy.shape[1] == z.shape[0]
assert all_dy.shape[2] == 3
assert all_neg_dy.shape[0] == all_y.shape[0]
assert all_neg_dy.shape[1] == z.shape[0]
assert all_neg_dy.shape[2] == 3

for pos, y, dy in zip(all_pos, all_y, all_dy):
for pos, y, neg_dy in zip(all_pos, all_y, all_neg_dy):

if y.isnan() or dy.isnan().any():
if y.isnan() or neg_dy.isnan().any():
continue

# Create a sample
args = dict(z=z, pos=pos, y=y.view(1, 1), dy=dy)
args = dict(z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy)
if mol_ids:
args["mol_id"] = mol_id
data = Data(**args)
Expand Down

0 comments on commit d091c8f

Please sign in to comment.