Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consistent naming between dy and forces #121

Merged
merged 6 commits into from
Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/ET-ANI1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ dataset_root: ~/data
derivative: false
distance_influence: both
early_stopping_patience: 500
ema_alpha_dy: 1.0
ema_alpha_neg_dy: 1.0
ema_alpha_y: 1.0
embed_files: null
embedding_dimension: 128
energy_files: null
energy_weight: 1.0
y_weight: 1.0
force_files: null
force_weight: 1.0
neg_dy_weight: 1.0
inference_batch_size: 2048
load_model: null
log_dir: logs/
Expand Down
6 changes: 3 additions & 3 deletions examples/ET-MD17.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ dataset_root: ~/data
derivative: true
distance_influence: both
early_stopping_patience: 300
ema_alpha_dy: 1.0
ema_alpha_neg_dy: 1.0
ema_alpha_y: 0.05
embed_files: null
embedding_dimension: 128
energy_files: null
energy_weight: 0.2
y_weight: 0.2
force_files: null
force_weight: 0.8
neg_dy_weight: 0.8
inference_batch_size: 64
load_model: null
log_dir: logs/
Expand Down
6 changes: 3 additions & 3 deletions examples/ET-QM9.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ dataset_root: ~/data
derivative: false
distance_influence: both
early_stopping_patience: 150
ema_alpha_dy: 1.0
ema_alpha_neg_dy: 1.0
ema_alpha_y: 1.0
embed_files: null
embedding_dimension: 256
energy_files: null
energy_weight: 1.0
y_weight: 1.0
force_files: null
force_weight: 1.0
neg_dy_weight: 1.0
inference_batch_size: 128
load_model: null
log_dir: logs/
Expand Down
6 changes: 3 additions & 3 deletions examples/ET-SPICE.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ dataset_root: data
derivative: true
distance_influence: both
early_stopping_patience: 50
ema_alpha_dy: 1.0
ema_alpha_neg_dy: 1.0
ema_alpha_y: 1.0
embed_files: null
embedding_dimension: 128
energy_files: null
energy_weight: 0.5
y_weight: 0.5
force_files: null
force_weight: 0.5
neg_dy_weight: 0.5
inference_batch_size: 16
load_model: null
log_dir: logs/
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dataset_comp6.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_dataset_s66x8():
)
assert pt.allclose(sample.y, pt.tensor([[-47.5919]]))
assert pt.allclose(
sample.dy,
sample.neg_dy,
pt.tensor(
[
[0.2739, -0.2190, -0.0012],
Expand Down
2 changes: 1 addition & 1 deletion tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_custom(energy, forces, num_files, tmpdir, num_samples=100):
if energy:
assert hasattr(sample, "y"), "Sample doesn't contain energy"
PhilippThoelke marked this conversation as resolved.
Show resolved Hide resolved
if forces:
assert hasattr(sample, "dy"), "Sample doesn't contain forces"
assert hasattr(sample, "neg_dy"), "Sample doesn't contain forces"


def test_hdf5_multiprocessing(tmpdir, num_entries=100):
Expand Down
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def get(self, idx):
if self.energies is not None:
features["y"] = self.energies[idx].clone()
if self.forces is not None:
features["dy"] = self.forces[idx].clone()
features["neg_dy"] = self.forces[idx].clone()
return Data(**features)

def len(self):
Expand Down
2 changes: 1 addition & 1 deletion torchmdnet/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _get_dataloader(self, dataset, stage, store_dataloader=True):

def _standardize(self):
def get_energy(batch, atomref):
if batch.y is None:
if "y" not in batch or batch.y is None:
raise MissingEnergyException()

if atomref is None:
Expand Down
39 changes: 21 additions & 18 deletions torchmdnet/datasets/ace.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@


class Ace(Dataset):

def __init__(
self,
root=None,
Expand All @@ -33,7 +32,7 @@ def __init__(
z_name,
pos_name,
y_name,
dy_name,
neg_dy_name,
q_name,
pq_name,
dp_name,
Expand All @@ -44,8 +43,8 @@ def __init__(
pos_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3)
)
self.y_mm = np.memmap(y_name, mode="r", dtype=np.float64)
self.dy_mm = np.memmap(
dy_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3)
self.neg_dy_mm = np.memmap(
neg_dy_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3)
)
self.q_mm = np.memmap(q_name, mode="r", dtype=np.int8)
self.pq_mm = np.memmap(pq_name, mode="r", dtype=np.float32)
Expand Down Expand Up @@ -78,7 +77,9 @@ def sample_iter(self, mol_ids=False):
for path in tqdm(self.raw_paths, desc="Files"):
molecules = list(h5py.File(path).items())

for i_mol, (mol_id, mol) in tqdm(enumerate(molecules), desc="Molecules", leave=False):
for i_mol, (mol_id, mol) in tqdm(
enumerate(molecules), desc="Molecules", leave=False
):

# Subsample molecules
if i_mol % self.subsample_molecules != 0:
Expand All @@ -103,8 +104,8 @@ def sample_iter(self, mol_ids=False):
assert y.shape == ()

assert conf["forces"].attrs["units"] == "eV/Å"
dy = -pt.tensor(conf["forces"], dtype=pt.float32)
assert dy.shape == pos.shape
neg_dy = -pt.tensor(conf["forces"], dtype=pt.float32)
assert neg_dy.shape == pos.shape

assert conf["partial_charges"].attrs["units"] == "e"
pq = pt.tensor(conf["partial_charges"], dtype=pt.float32)
Expand All @@ -116,11 +117,13 @@ def sample_iter(self, mol_ids=False):

# Skip samples with large forces
if self.max_gradient:
if dy.norm(dim=1).max() > float(self.max_gradient):
if neg_dy.norm(dim=1).max() > float(self.max_gradient):
continue

# Create a sample
args = dict(z=z, pos=pos, y=y.view(1, 1), dy=dy, q=q, pq=pq, dp=dp)
args = dict(
z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, q=q, pq=pq, dp=dp
)
if mol_ids:
args["mol_id"] = mol_id
data = Data(**args)
Expand All @@ -140,7 +143,7 @@ def processed_file_names(self):
f"{self.name}.z.mmap",
f"{self.name}.pos.mmap",
f"{self.name}.y.mmap",
f"{self.name}.dy.mmap",
f"{self.name}.neg_dy.mmap",
f"{self.name}.q.mmap",
f"{self.name}.pq.mmap",
f"{self.name}.dp.mmap",
Expand All @@ -167,7 +170,7 @@ def process(self):
z_name,
pos_name,
y_name,
dy_name,
neg_dy_name,
q_name,
pq_name,
dp_name,
Expand All @@ -182,8 +185,8 @@ def process(self):
y_mm = np.memmap(
y_name + ".tmp", mode="w+", dtype=np.float64, shape=num_all_confs
)
dy_mm = np.memmap(
dy_name + ".tmp", mode="w+", dtype=np.float32, shape=(num_all_atoms, 3)
neg_dy_mm = np.memmap(
neg_dy_name + ".tmp", mode="w+", dtype=np.float32, shape=(num_all_atoms, 3)
)
q_mm = np.memmap(q_name + ".tmp", mode="w+", dtype=np.int8, shape=num_all_confs)
pq_mm = np.memmap(
Expand All @@ -202,7 +205,7 @@ def process(self):
z_mm[i_atom:i_next_atom] = data.z.to(pt.int8)
pos_mm[i_atom:i_next_atom] = data.pos
y_mm[i_conf] = data.y
dy_mm[i_atom:i_next_atom] = data.dy
neg_dy_mm[i_atom:i_next_atom] = data.neg_dy
q_mm[i_conf] = data.q.to(pt.int8)
pq_mm[i_atom:i_next_atom] = data.pq
dp_mm[i_conf] = data.dp
Expand All @@ -216,7 +219,7 @@ def process(self):
z_mm.flush()
pos_mm.flush()
y_mm.flush()
dy_mm.flush()
neg_dy_mm.flush()
q_mm.flush()
pq_mm.flush()
dp_mm.flush()
Expand All @@ -225,7 +228,7 @@ def process(self):
os.rename(z_mm.filename, z_name)
os.rename(pos_mm.filename, pos_name)
os.rename(y_mm.filename, y_name)
os.rename(dy_mm.filename, dy_name)
os.rename(neg_dy_mm.filename, neg_dy_name)
os.rename(q_mm.filename, q_name)
os.rename(pq_mm.filename, pq_name)
os.rename(dp_mm.filename, dp_name)
Expand All @@ -241,9 +244,9 @@ def get(self, idx):
y = pt.tensor(self.y_mm[idx], dtype=pt.float32).view(
1, 1
) # It would be better to use float64, but the trainer complaints
dy = pt.tensor(self.dy_mm[atoms], dtype=pt.float32)
neg_dy = pt.tensor(self.neg_dy_mm[atoms], dtype=pt.float32)
q = pt.tensor(self.q_mm[idx], dtype=pt.long)
pq = pt.tensor(self.pq_mm[atoms], dtype=pt.float32)
dp = pt.tensor(self.dp_mm[idx], dtype=pt.float32)

return Data(z=z, pos=pos, y=y, dy=dy, q=q, pq=pq, dp=dp)
return Data(z=z, pos=pos, y=y, neg_dy=neg_dy, q=q, pq=pq, dp=dp)
57 changes: 28 additions & 29 deletions torchmdnet/datasets/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,18 @@ def __init__(
self.name = self.__class__.__name__
super().__init__(root, transform, pre_transform, pre_filter)

idx_name, z_name, pos_name, y_name, dy_name = self.processed_paths
idx_name, z_name, pos_name, y_name, neg_dy_name = self.processed_paths
self.idx_mm = np.memmap(idx_name, mode="r", dtype=np.int64)
self.z_mm = np.memmap(z_name, mode="r", dtype=np.int8)
self.pos_mm = np.memmap(
pos_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3)
)
self.y_mm = np.memmap(y_name, mode="r", dtype=np.float64)
self.dy_mm = (
self.neg_dy_mm = (
np.memmap(
dy_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3)
neg_dy_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3)
)
if os.path.getsize(dy_name) > 0
if os.path.getsize(neg_dy_name) > 0
else None
)

Expand All @@ -66,7 +66,7 @@ def processed_file_names(self):
f"{self.name}.z.mmap",
f"{self.name}.pos.mmap",
f"{self.name}.y.mmap",
f"{self.name}.dy.mmap",
f"{self.name}.neg_dy.mmap",
]

def filter_and_pre_transform(self, data):
Expand All @@ -80,20 +80,19 @@ def filter_and_pre_transform(self, data):
return data

def process(self):

print("Gathering statistics...")
num_all_confs = 0
num_all_atoms = 0
for data in self.sample_iter():
num_all_confs += 1
num_all_atoms += data.z.shape[0]
has_dy = "dy" in data
has_neg_dy = "neg_dy" in data

print(f" Total number of conformers: {num_all_confs}")
print(f" Total number of atoms: {num_all_atoms}")
print(f" Forces available: {has_dy}")
print(f" Forces available: {has_neg_dy}")

idx_name, z_name, pos_name, y_name, dy_name = self.processed_paths
idx_name, z_name, pos_name, y_name, neg_dy_name = self.processed_paths
idx_mm = np.memmap(
idx_name + ".tmp", mode="w+", dtype=np.int64, shape=(num_all_confs + 1,)
)
Expand All @@ -106,12 +105,12 @@ def process(self):
y_mm = np.memmap(
y_name + ".tmp", mode="w+", dtype=np.float64, shape=(num_all_confs,)
)
dy_mm = (
neg_dy_mm = (
np.memmap(
dy_name + ".tmp", mode="w+", dtype=np.float32, shape=(num_all_atoms, 3)
neg_dy_name + ".tmp", mode="w+", dtype=np.float32, shape=(num_all_atoms, 3)
)
if has_dy
else open(dy_name, "w")
if has_neg_dy
else open(neg_dy_name, "w")
)

print("Storing data...")
Expand All @@ -123,8 +122,8 @@ def process(self):
z_mm[i_atom:i_next_atom] = data.z.to(pt.int8)
pos_mm[i_atom:i_next_atom] = data.pos
y_mm[i_conf] = data.y
if has_dy:
dy_mm[i_atom:i_next_atom] = data.dy
if has_neg_dy:
neg_dy_mm[i_atom:i_next_atom] = data.neg_dy

i_atom = i_next_atom

Expand All @@ -135,15 +134,15 @@ def process(self):
z_mm.flush()
pos_mm.flush()
y_mm.flush()
if has_dy:
dy_mm.flush()
if has_neg_dy:
neg_dy_mm.flush()

os.rename(idx_mm.filename, idx_name)
os.rename(z_mm.filename, z_name)
os.rename(pos_mm.filename, pos_name)
os.rename(y_mm.filename, y_name)
if has_dy:
os.rename(dy_mm.filename, dy_name)
if has_neg_dy:
os.rename(neg_dy_mm.filename, neg_dy_name)

def len(self):
return len(self.y_mm)
Expand All @@ -158,11 +157,11 @@ def get(self, idx):
) # It would be better to use float64, but the trainer complaints
y -= self.compute_reference_energy(z)

if self.dy_mm is None:
if self.neg_dy_mm is None:
return Data(z=z, pos=pos, y=y)
else:
dy = pt.tensor(self.dy_mm[atoms], dtype=pt.float32)
return Data(z=z, pos=pos, y=y, dy=dy)
neg_dy = pt.tensor(self.neg_dy_mm[atoms], dtype=pt.float32)
return Data(z=z, pos=pos, y=y, neg_dy=neg_dy)


class ANI1(ANIBase):
Expand Down Expand Up @@ -282,25 +281,25 @@ def sample_iter(self, mol_ids=False):
all_y = pt.tensor(
mol["wb97x_dz.energy"][:] * self.HARTREE_TO_EV, dtype=pt.float64
)
all_dy = pt.tensor(
all_neg_dy = pt.tensor(
mol["wb97x_dz.forces"][:] * self.HARTREE_TO_EV, dtype=pt.float32
)

assert all_pos.shape[0] == all_y.shape[0]
assert all_pos.shape[1] == z.shape[0]
assert all_pos.shape[2] == 3

assert all_dy.shape[0] == all_y.shape[0]
assert all_dy.shape[1] == z.shape[0]
assert all_dy.shape[2] == 3
assert all_neg_dy.shape[0] == all_y.shape[0]
assert all_neg_dy.shape[1] == z.shape[0]
assert all_neg_dy.shape[2] == 3

for pos, y, dy in zip(all_pos, all_y, all_dy):
for pos, y, neg_dy in zip(all_pos, all_y, all_neg_dy):

if y.isnan() or dy.isnan().any():
if y.isnan() or neg_dy.isnan().any():
continue

# Create a sample
args = dict(z=z, pos=pos, y=y.view(1, 1), dy=dy)
args = dict(z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy)
if mol_ids:
args["mol_id"] = mol_id
data = Data(**args)
Expand Down