
use shift for gradient calculation instead of cell #13

Merged
merged 6 commits into master on Sep 21, 2021

Conversation

So-Takamoto
Contributor

I found that the current implementation has several performance issues regarding the gradient w.r.t. cell.
This PR addresses them. Since the changes are relatively large, I will add some inline comments.

Change summary:

  • Use shift for the gradient calculation instead of cell.
  • shift is now in length units instead of cell (fractional) units.
  • Calculate the Voigt-notation stress directly.

Also, this PR contains a bugfix related to skewed cells.

@So-Takamoto So-Takamoto added bug Something isn't working enhancement New feature or request labels Sep 16, 2021
if cell is None:
shift = S
else:
shift = torch.mm(S, cell.detach())
Contributor Author

The definition of shift differs from the previous one (current_shift = previous_shift * cell).
This is because the gradient is now calculated w.r.t. shift directly instead of w.r.t. cell.
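As a minimal sketch with made-up values (S and cell are hypothetical), the relation between the two definitions can be written as:

```python
import torch

# Hypothetical values: S holds integer periodic-image offsets, (n_edges, 3).
S = torch.tensor([[1.0, 0.0, 0.0], [0.0, -1.0, 1.0]])
cell = torch.eye(3) * 2.0  # lattice vectors as rows, (3, 3)

previous_shift = S                          # old: cell (fractional) units
current_shift = torch.mm(S, cell.detach())  # new: length units
# i.e. current_shift = previous_shift * cell
```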

# shift (n_edges, 3), cell[batch] (n_atoms, 3, 3) -> offsets (n_edges, 3)
offsets = torch.bmm(shift[:, None, :], cell[batch_edge])[:, 0]
Rj += offsets
Rj += shift
Contributor Author

Due to the change of the shift definition, the calc_distances() function changes.

Member

Can you add a docstring describing the shift definition? I think shift is now a float shift vector for the distance, instead of an integer shift value.

pos_bohr = pos / d3_autoang # angstrom -> bohr
if cell is None:
cell_bohr = None
cell_bohr: Optional[Tensor] = None
Contributor Author

Because of the change to calc_distances(), cell is no longer used within dftd3_module, but I kept it just in case.
The same applies to dftd2_module.

# pos (n_atoms, 1, 3) * cell (n_atoms, 3, 3) -> (n_atoms, 3)
pos = torch.bmm(rel_pos[:, None, :].detach(), cell[batch])[:, 0]
assert isinstance(shift, Tensor)
shift.requires_grad_(True)
Contributor Author

pos and shift are explicitly defined as gradient variables here.
The effect of cell applies to both pos and shift. In the previous implementation, this was handled by redefining pos and building the autograd graph.
In the new implementation, both effects are computed explicitly later, so pos no longer needs a computation graph back to cell. This change removes the cell[batch] calculation.
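A minimal sketch (assumed shapes and illustrative names, not the PR's exact code) of marking pos and shift as explicit gradient variables without a graph back to cell:

```python
import torch

cell = torch.eye(3)[None]                 # (n_graphs=1, 3, 3)
rel_pos = torch.rand(4, 3)                # fractional coordinates, (n_atoms, 3)
batch = torch.zeros(4, dtype=torch.long)  # graph index per atom

# Built from detached inputs, so pos carries no autograd graph to cell.
pos = torch.bmm(rel_pos[:, None, :].detach(), cell[batch])[:, 0]
pos.requires_grad_(True)

shift = torch.rand(5, 3)                  # (n_edges, 3), length-unit shifts
shift.requires_grad_(True)
```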

if batch is None:
cell_volume = torch.det(cell).abs()
stress = torch.mm(cell.grad, cell.T) / cell_volume
Contributor Author

There was a bug in this line.
The correct implementation should be stress = torch.mm(cell.T, cell.grad) / cell_volume (matrix order).
This can be checked by running the test with a skewed-cell input.
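A small sketch (hypothetical numbers) illustrating why the bug only appears with skewed cells: for a non-orthogonal cell the two multiplication orders disagree, while for a diagonal cell their diagonal entries coincide.

```python
import torch

cell = torch.tensor([[2.0, 0.0, 0.0],
                     [0.5, 2.0, 0.0],
                     [0.0, 0.3, 2.0]])       # skewed (non-orthogonal) cell
cell_grad = torch.tensor([[1.0, 0.2, 0.0],
                          [0.1, 1.0, 0.3],
                          [0.0, 0.2, 1.0]])  # stand-in for cell.grad
cell_volume = torch.det(cell).abs()

buggy = torch.mm(cell_grad, cell.T) / cell_volume  # previous order
fixed = torch.mm(cell.T, cell_grad) / cell_volume  # corrected matrix order
```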

)
cell_grad += torch.sum(
(shift[:, voigt_left] * shift.grad[:, voigt_right]).to(torch.float64), dim=0
)
Contributor Author

A step-by-step explanation of these lines:

The previous implementation used cell.grad. cell was used in two places: one is pos, and the other is the distance calculation combined with shift. Therefore, we can write down the calculation of cell.grad by explicitly writing out the backpropagation of cell->pos and cell->shift. The two statements above correspond to them.

I will use the character * for matrix multiplication here. For the pos side, the forward calculation is
pos = rel_pos * cell,

therefore the backward calculation is
cell.grad = rel_pos.T * pos.grad.

Since cell.T is multiplied from the left later (line 153 in the previous implementation; see the corresponding comment), the calculation becomes
cell.T * cell.grad = cell.T * rel_pos.T * pos.grad = pos.T * pos.grad.

Here, the shapes of pos and pos.grad are both (n_node, 3), the matrix calculation is (3, n_node) * (n_node, 3) -> (3, 3), and it can also be written explicitly as
pos.T * pos.grad = torch.sum(pos[:, :, None] * pos.grad[:, None, :], dim=0).

The obtained value is the stress tensor, and from physical considerations we know it is always symmetric. Therefore we only pick up the diagonal and upper-triangular values (Voigt notation). Looking at the calculation above, we can modify it so that only the needed values are computed:
stress = (pos.T * pos.grad).view(-1)[[0, 4, 8, 5, 2, 1]] = torch.sum(pos[:, [0,1,2,1,2,0]] * pos.grad[:, [0,1,2,2,0,1]], dim=0)

This is exactly what the new implementation does. Since the summed axis has length n_node, we use FP64 to avoid numerical error.

Almost the same discussion applies to the relation between cell.grad and shift.grad.
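The index trick above can be checked numerically with a small sketch. Here pos_grad = pos is used so that pos.T * pos.grad is symmetric, mirroring the physical symmetry of the stress tensor (for an asymmetric product the xz/zx components would differ):

```python
import torch

torch.manual_seed(0)
n_node = 8
pos = torch.rand(n_node, 3)
pos_grad = pos.clone()  # makes pos.T @ pos_grad symmetric, like a stress

# Full (3, 3) product, flattened, then Voigt order xx, yy, zz, yz, xz, xy.
full = torch.mm(pos.T, pos_grad).view(-1)[[0, 4, 8, 5, 2, 1]]

# Direct computation of only the six needed components.
direct = torch.sum(
    pos[:, [0, 1, 2, 1, 2, 0]] * pos_grad[:, [0, 1, 2, 2, 0, 1]], dim=0
)
```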

0,
batch_edge.view(batch_edge.size()[0], 1).expand(batch_edge.size()[0], 6),
(shift[:, voigt_left] * shift.grad[:, voigt_right]).to(torch.float64),
)
Contributor Author

These calculations are almost the same. The only difference is the use of scatter_add instead of the summation.
This implementation eliminates cell[batch] and cell[batch_edge], which would cause a slowdown.

edge_index_abc,
shift=shift_abc,
dtype=pos.dtype,
batch_edge=batch_edge_abc,
Contributor Author

Since this function has a special cupy backend and does not support backprop, we cannot pass shift or values derived from it into this function.
So I modified its interface: instead of returning triplet_shift directly, it returns the indices j and k as edge_jk. triplet_shift can then be calculated from edge_jk.
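A sketch (assumed shapes, matching the per-edge shift convention used later in this PR) of how the triplet shift vectors can be recovered from edge_jk afterwards:

```python
import torch

shift_abc = torch.rand(10, 3)           # per-edge shift vectors, (n_edges, 3)
edge_jk = torch.randint(0, 10, (4, 2))  # (n_triplets, 2): edge ids for j and k

ts0 = -shift_abc[edge_jk[:, 0]]         # shift used for the i-j distance
ts1 = -shift_abc[edge_jk[:, 1]]         # shift used for the i-k distance
ts2 = shift_abc[edge_jk[:, 0]] - shift_abc[edge_jk[:, 1]]  # j-k shift
```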

[
_offset + _j,
_offset + _k,
]
Contributor Author

Small change: since the internal array is small, I just use a Python list here.

"raw int64 counts, raw int64 unique, raw int64 dst, raw T shift, raw int64 batch_edge, raw int64 counts_cumsum",
"raw int64 triplet_node_index, raw T multiplicity, raw T triplet_shift, raw int64 batch_triplets",
"raw int64 counts, raw int64 unique, raw int64 dst, raw int64 edge_indices, raw int64 batch_edge, raw int64 counts_cumsum",
"raw int64 triplet_node_index, raw T multiplicity, raw int64 edge_jk, raw int64 batch_triplets",
Contributor Author

The interface change also affects the cupy implementation. Switching from shift to indices changes the data type.

g = e68.new_zeros((n_graphs,))
g.scatter_add_(0, batch_edge, e68)
g = e68.new_zeros((n_graphs,), dtype=torch.float64)
g.scatter_add_(0, batch_edge, e68.to(torch.float64))
Contributor Author

I found that larger systems cause numerical error, and the origin is here. Another problem is the nondeterministic behavior of scatter_add. To avoid fluctuations during MD, I think it is good to convert values to FP64 explicitly in these specific situations.
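A minimal sketch (illustrative names and shapes) of the FP64 accumulation shown in the diff above:

```python
import torch

n_graphs = 2
e68 = torch.rand(1000, dtype=torch.float32)       # per-edge energy terms
batch_edge = torch.randint(0, n_graphs, (1000,))  # graph index per edge

# Accumulate in FP64 to reduce round-off error for large systems.
g = e68.new_zeros((n_graphs,), dtype=torch.float64)
g.scatter_add_(0, batch_edge, e68.to(torch.float64))
```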

pytest.param("large", marks=[pytest.mark.slow], id="large"),
]
)
def atoms(request) -> Atoms:
Contributor Author

I added a larger-system test with the pytest.mark.slow custom marker (it won't run on CI because CI uses the -m "not slow" option). atoms is now defined as a pytest.fixture.

@@ -105,25 +124,22 @@ def _test_calc_energy_force_stress(
abc=abc,
bidirectional=bidirectional,
)
_assert_energy_force_stress_equal(dftd3_calc, torch_dftd3_calc, atoms)
_assert_energy_force_stress_equal(dftd3_calc, torch_dftd3_calc, atoms, force_tol=force_tol)
Contributor Author

I found that FP32 sometimes fails the accuracy check. I relaxed the threshold for FP32 cases.

@@ -0,0 +1,3 @@
[tool:pytest]
markers =
slow: mark test as slow.
Contributor Author

This file registers the custom marker with pytest. It reduces pytest warnings.

@@ -6,7 +6,7 @@
setup_requires: List[str] = []
install_requires: List[str] = [
"ase>=3.18, <4.0.0", # Note that we require ase==3.21.1 for pytest.
"pymatgen",
"pymatgen>=2020.1.28",
Contributor Author

pymatgen's get_neighbor_list is relatively new, so I added a minimum version requirement here.

@So-Takamoto So-Takamoto marked this pull request as ready for review September 16, 2021 10:09
Member

@corochann corochann left a comment

Thank you! I added several comments, but the implementation itself LGTM.

@@ -20,7 +20,7 @@ main() {
docker run --runtime=nvidia --rm --volume="$(pwd)":/workspace -w /workspace \
${IMAGE} \
bash -x -c "pip install flake8 pytest pytest-cov pytest-xdist pytest-benchmark && \
pip install cupy-cuda102 pytorch-pfn-extras && \
pip install cupy-cuda102 pytorch-pfn-extras!=0.5.0 && \
Member

What kind of error did you get with 0.5.0?

Contributor Author

I still don't know the details, but you can see the error in the CI log. I guess the combination of this specific version and Python 3.6 caused the problem.


atoms_dict = {"mol": mol, "slab": slab, "large": large_bulk}

return atoms_dict[request.param]
Member

Just a comment (low priority): if only one of mol, slab, or large_bulk is used, it looks more efficient to use an if-else block (switch) instead of creating all atoms every time.

Contributor Author

I agree with the comment, but this makes it easier to unify the implementations of test_torch_dftd3_calculator and test_torch_dftd3_calculator_batch (the latter uses multiple atoms). So if speed is not a problem, I think we can keep this style for now.

calc1.reset()
atoms.calc = calc1
f1 = atoms.get_forces()
e1 = atoms.get_potential_energy()
if np.all(atoms.pbc == np.array([True, True, True])):
s1 = atoms.get_stress()
Member

why do we need to move this block?

Contributor Author

The previous implementation calculated s1 using `calc2`, and thus the stress check was not done correctly.

"case4": [large_bulk],
}

return atoms_dict[request.param]
Member

Just a comment (low priority): if only one of mol, slab, or large_bulk is used, it looks more efficient to use an if-else block (switch) instead of creating all atoms every time.


)
batch_triplets = None if batch_edge is None else batch_triplets

# Apply `cnthr` cutoff threshold for r_kj
idx_j, idx_k = triplet_node_index[:, 1], triplet_node_index[:, 2]
ts2 = triplet_shift[:, 2]
r_jk = calc_distances(pos, torch.stack([idx_j, idx_k], dim=0), cell, ts2, batch_triplets)
ts2 = None if shift_abc is None else shift_abc[edge_jk[:, 0]] - shift_abc[edge_jk[:, 1]]
Member

I think you can now change the variable name ts2 to something else. shift_jk?

@@ -306,10 +315,11 @@ def edisp(
triplet_node_index[:, 1],
triplet_node_index[:, 2],
)
ts0, ts1, ts2 = triplet_shift[:, 0], triplet_shift[:, 1], triplet_shift[:, 2]
ts0 = None if shift_abc is None else -shift_abc[edge_jk[:, 0]]
Member

shift_ij?

@@ -306,10 +315,11 @@ def edisp(
triplet_node_index[:, 1],
triplet_node_index[:, 2],
)
ts0, ts1, ts2 = triplet_shift[:, 0], triplet_shift[:, 1], triplet_shift[:, 2]
ts0 = None if shift_abc is None else -shift_abc[edge_jk[:, 0]]
ts1 = None if shift_abc is None else -shift_abc[edge_jk[:, 1]]
Member

shift_ik?

shift = shift[is_larger][sort_inds]
edge_indices = torch.arange(shift.shape[0], dtype=torch.long, device=edge_index.device)
edge_indices = edge_indices[is_larger][sort_inds]
# shift = shift[is_larger][sort_inds]
Member

It's okay to remove this comment, since shift is not used now.
You can also move the edge_indices initialization out of the if-else block.

_offset + _j,
_offset + _k,
]
# torch.stack([-_shift[_j], -_shift[_k], _shift[_j] - _shift[_k]], dim=0)
Member

remove this line

Member

@corochann corochann left a comment

LGTM!

@corochann corochann merged commit dd8644e into pfnet-research:master Sep 21, 2021
shinh pushed a commit to shinh/torch-dftd that referenced this pull request Jun 15, 2023