use shift for gradient calculation instead of cell #13
Conversation
torch_dftd/torch_dftd3_calculator.py
Outdated
if cell is None:
    shift = S
else:
    shift = torch.mm(S, cell.detach())
The definition of shift is different from the previous one (current_shift = previous_shift * cell).
This is because the gradient is calculated with respect to shift directly instead of cell.
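A minimal sketch of the new shift definition (variable names assumed, not the library's exact code): the integer image shifts S are pre-multiplied by the detached cell, so shift becomes a Cartesian displacement vector that can itself be a gradient variable.

```python
import torch

# Sketch (names assumed): integer image shifts S (n_edges, 3) are
# pre-multiplied by the detached cell, so `shift` is a Cartesian
# displacement and the gradient can be taken wrt. `shift` directly.
cell = 5.0 * torch.eye(3, dtype=torch.float64)
S = torch.tensor([[1.0, 0.0, 0.0],
                  [0.0, -1.0, 1.0]], dtype=torch.float64)

shift = torch.mm(S, cell.detach())  # (n_edges, 3), no graph back to cell
shift.requires_grad_(True)          # gradient variable for the stress later
```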
torch_dftd/functions/distance.py
Outdated
# shift (n_edges, 3), cell[batch] (n_atoms, 3, 3) -> offsets (n_edges, 3)
offsets = torch.bmm(shift[:, None, :], cell[batch_edge])[:, 0]
Rj += offsets
Rj += shift
Due to the change of the shift definition, the calc_distances() function changes.
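A minimal sketch of the new distance calculation under these assumptions (simplified names, not the library's exact signature): since shift is already a Cartesian displacement, it is added to Rj directly, with no per-edge cell lookup or bmm required.

```python
import torch

# Sketch (simplified, names assumed): `shift` is already Cartesian,
# so it is added to Rj directly; no cell[batch_edge] gather or bmm needed.
pos = torch.randn(4, 3, dtype=torch.float64)
edge_index = torch.tensor([[0, 2],   # source atoms i
                           [1, 3]])  # target atoms j
shift = torch.zeros(2, 3, dtype=torch.float64)  # Cartesian, (n_edges, 3)

Ri = pos[edge_index[0]]
Rj = pos[edge_index[1]] + shift
dist = torch.linalg.norm(Rj - Ri, dim=1)
```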
Can you add a docstring describing the shift definition? I think shift is now a float shift vector for the distance instead of an integer shift value.
pos_bohr = pos / d3_autoang  # angstrom -> bohr
if cell is None:
    cell_bohr = None
    cell_bohr: Optional[Tensor] = None
Because of the change of calc_distances(), cell is no longer used within dftd3_module, but I kept it just in case.
The same applies to dftd2_module.
torch_dftd/nn/base_dftd_module.py
Outdated
# pos (n_atoms, 1, 3) * cell (n_atoms, 3, 3) -> (n_atoms, 3)
pos = torch.bmm(rel_pos[:, None, :].detach(), cell[batch])[:, 0]
assert isinstance(shift, Tensor)
shift.requires_grad_(True)
pos and shift are explicitly defined as gradient variables here.
The effect of cell must be applied to both pos and shift. In the previous implementation, this was done by redefining pos and building a computation graph from cell.
In the new implementation, both effects are computed explicitly later, so pos no longer needs a computation graph from cell. This change removes the cell[batch] calculation.
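The idea can be sketched as follows for a single-graph case (variable names assumed): pos is built from detached tensors and then marked as a gradient variable, as is shift, so forces come from pos.grad and the cell contribution to the stress is reconstructed later from pos.grad and shift.grad.

```python
import torch

# Sketch (single-graph case, names assumed): pos and shift are leaf-like
# gradient variables; no computation graph from `cell` is kept.
cell = 5.0 * torch.eye(3, dtype=torch.float64)
rel_pos = torch.rand(4, 3, dtype=torch.float64)

pos = torch.mm(rel_pos.detach(), cell.detach())  # no graph back to cell
pos.requires_grad_(True)

shift = torch.zeros(2, 3, dtype=torch.float64)
shift.requires_grad_(True)
```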
if batch is None:
    cell_volume = torch.det(cell).abs()
    stress = torch.mm(cell.grad, cell.T) / cell_volume
There was a bug in this line.
The correct implementation should be stress = torch.mm(cell.T, cell.grad) / cell_volume (matrix order).
This can be checked by running the test with a skewed cell input.
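A quick standalone check of why the matrix order matters: for a skewed (non-orthogonal) cell the two orders give different results, which is why the bug only showed up with skewed-cell tests. Here grad is a stand-in for cell.grad, not a value from the library.

```python
import torch

# For a skewed cell, cell.T @ grad != grad @ cell.T in general.
# `grad` is a hypothetical stand-in for cell.grad.
cell = torch.tensor([[5.0, 0.0, 0.0],
                     [1.0, 5.0, 0.0],
                     [0.0, 2.0, 5.0]], dtype=torch.float64)
grad = torch.arange(9, dtype=torch.float64).reshape(3, 3)

wrong = torch.mm(grad, cell.T)   # previous (buggy) order
right = torch.mm(cell.T, grad)   # corrected order
print(torch.allclose(wrong, right))  # False
```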
)
cell_grad += torch.sum(
    (shift[:, voigt_left] * shift.grad[:, voigt_right]).to(torch.float64), dim=0
)
Step-by-step explanation of these lines:
The previous implementation used cell.grad. cell was used in two places: one is pos, and the other is the distance calculation combined with shift. Therefore, we can write down the calculation of cell.grad by explicitly writing out the backpropagation of cell -> pos and cell -> shift. The two statements above correspond to these.
I will use the character * for matrix multiplication here. For the pos side, the forward calculation is
pos = rel_pos * cell,
so the backward calculation is
cell.grad = rel_pos.T * pos.grad.
Since cell.T will be multiplied from the left later (line 153 in the previous implementation, see the corresponding comment), the calculation becomes
cell.T * cell.grad = cell.T * rel_pos.T * pos.grad = pos.T * pos.grad.
Here, the shapes of pos and pos.grad are both (n_node, 3), so the matrix calculation is (3, n_node) * (n_node, 3) -> (3, 3), and it can also be written explicitly as
pos.T * pos.grad = torch.sum(pos[:, :, None] * pos.grad[:, None, :], dim=0).
The obtained value is the stress tensor, and from physical considerations we know it is always symmetric. Therefore we only pick the diagonal and upper-triangular values (Voigt notation). Looking at the above calculation, we can modify it so that only the needed values are computed:
stress = (pos.T * pos.grad).view(-1)[[0, 4, 8, 5, 2, 1]] = torch.sum(pos[:, [0,1,2,1,2,0]] * pos.grad[:, [0,1,2,2,0,1]], dim=0)
This is exactly what the new implementation does. Since the summed axis length is n_node, we use FP64 to avoid numerical error.
Almost the same discussion applies to the relation between cell.grad and shift.grad.
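The Voigt-indexed form can be checked numerically with a standalone sketch (random tensors, grad standing in for pos.grad): the six gathered components equal the corresponding entries of the full (3, 3) product.

```python
import torch

# Check: the Voigt gather equals the matching entries of pos.T @ grad.
torch.manual_seed(0)
n_node = 5
pos = torch.randn(n_node, 3, dtype=torch.float64)
grad = torch.randn(n_node, 3, dtype=torch.float64)  # stand-in for pos.grad

full = pos.T @ grad  # same as torch.sum(pos[:, :, None] * grad[:, None, :], dim=0)

# Component pairs (0,0), (1,1), (2,2), (1,2), (2,0), (0,1)
voigt_left = [0, 1, 2, 1, 2, 0]
voigt_right = [0, 1, 2, 2, 0, 1]
voigt = torch.sum(pos[:, voigt_left] * grad[:, voigt_right], dim=0)

expected = torch.stack([full[i, j] for i, j in zip(voigt_left, voigt_right)])
print(torch.allclose(voigt, expected))  # True
```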
0,
batch_edge.view(batch_edge.size()[0], 1).expand(batch_edge.size()[0], 6),
(shift[:, voigt_left] * shift.grad[:, voigt_right]).to(torch.float64),
)
These calculations are almost the same. The only difference is the use of scatter_add instead of summation.
This implementation eliminates the cell[batch] and cell[batch_edge] gathers, which would cause a slowdown.
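The batched variant can be sketched with small hypothetical tensors (sgrad stands in for shift.grad): scatter_add_ accumulates the per-edge Voigt products into a per-graph cell gradient.

```python
import torch

# Sketch (hypothetical small tensors): per-graph accumulation of the
# Voigt products with scatter_add_, no cell[batch_edge] gather needed.
n_graphs = 2
shift = torch.randn(4, 3, dtype=torch.float64)
sgrad = torch.randn(4, 3, dtype=torch.float64)  # stand-in for shift.grad
batch_edge = torch.tensor([0, 0, 1, 1])         # graph id per edge
voigt_left = [0, 1, 2, 1, 2, 0]
voigt_right = [0, 1, 2, 2, 0, 1]

cell_grad = torch.zeros(n_graphs, 6, dtype=torch.float64)
cell_grad.scatter_add_(
    0,
    batch_edge.view(-1, 1).expand(-1, 6),
    (shift[:, voigt_left] * sgrad[:, voigt_right]).to(torch.float64),
)
```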
edge_index_abc,
shift=shift_abc,
dtype=pos.dtype,
batch_edge=batch_edge_abc,
Since this function has a special cupy backend and does not support backprop, we cannot pass shift and its derived values into this function.
So I modified its interface. Instead of returning triplet_shift directly, it returns the indices j and k as edge_jk. triplet_shift can then be calculated using edge_jk.
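Recovering the triplet shift vectors from edge_jk can be sketched as follows (names assumed), following the sign convention that appears later in this PR:

```python
import torch

# Sketch (names assumed): per-triplet shifts recovered from the two edge
# indices in edge_jk instead of carrying triplet_shift through the kernel.
shift_abc = torch.randn(5, 3, dtype=torch.float64)  # per-edge shift vectors
edge_jk = torch.tensor([[0, 1],
                        [2, 4]])                    # (n_triplets, 2)

shift_ij = -shift_abc[edge_jk[:, 0]]
shift_ik = -shift_abc[edge_jk[:, 1]]
shift_jk = shift_abc[edge_jk[:, 0]] - shift_abc[edge_jk[:, 1]]
```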
[
    _offset + _j,
    _offset + _k,
]
Small change: since the internal array is small, I just use a Python list here.
"raw int64 counts, raw int64 unique, raw int64 dst, raw T shift, raw int64 batch_edge, raw int64 counts_cumsum",
"raw int64 triplet_node_index, raw T multiplicity, raw T triplet_shift, raw int64 batch_triplets",
"raw int64 counts, raw int64 unique, raw int64 dst, raw int64 edge_indices, raw int64 batch_edge, raw int64 counts_cumsum",
"raw int64 triplet_node_index, raw T multiplicity, raw int64 edge_jk, raw int64 batch_triplets",
The interface change also affects the cupy implementation. Switching from shift to indices changes the data type.
g = e68.new_zeros((n_graphs,))
g.scatter_add_(0, batch_edge, e68)
g = e68.new_zeros((n_graphs,), dtype=torch.float64)
g.scatter_add_(0, batch_edge, e68.to(torch.float64))
I found that larger systems cause numerical error, and the origin is here. Another problem is the nondeterministic behavior of scatter_add. To avoid fluctuations during MD, I think it is good to convert values to FP64 explicitly in these specific situations.
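The FP64 accumulation can be illustrated with a standalone sketch: summing a million small per-edge energies, the FP64 accumulator stays close to the exact value (100.0 here), limited only by the FP32 rounding of the inputs themselves.

```python
import torch

# Illustration: accumulate many small FP32 energies into an FP64 sum,
# as the new implementation does with scatter_add_.
n_edges = 1_000_000
e68 = torch.full((n_edges,), 1e-4, dtype=torch.float32)
batch_edge = torch.zeros(n_edges, dtype=torch.long)  # single graph

g64 = e68.new_zeros((1,), dtype=torch.float64)
g64.scatter_add_(0, batch_edge, e68.to(torch.float64))

print(abs(g64.item() - 100.0))  # small: only FP32 input rounding remains
```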
pytest.param("large", marks=[pytest.mark.slow], id="large"),
    ]
)
def atoms(request) -> Atoms:
I added a larger system test with the pytest.mark.slow custom marker (it won't run on CI because CI uses the -m "not slow" option). atoms is now defined as a pytest.fixture.
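A simplified sketch of this fixture pattern (the real fixture builds ase.Atoms objects): the "large" case carries the slow marker, so pytest -m "not slow" skips it while the other cases still run.

```python
import pytest

# Sketch of the parametrized fixture: only "large" carries the slow marker.
@pytest.fixture(
    params=[
        pytest.param("mol", id="mol"),
        pytest.param("slab", id="slab"),
        pytest.param("large", marks=[pytest.mark.slow], id="large"),
    ]
)
def atoms(request):
    # The real fixture returns an ase.Atoms object built per request.param.
    return request.param
```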
@@ -105,25 +124,22 @@ def _test_calc_energy_force_stress(
    abc=abc,
    bidirectional=bidirectional,
)
_assert_energy_force_stress_equal(dftd3_calc, torch_dftd3_calc, atoms)
_assert_energy_force_stress_equal(dftd3_calc, torch_dftd3_calc, atoms, force_tol=force_tol)
I found that FP32 sometimes fails the accuracy check. I modified the threshold for FP32 cases.
@@ -0,0 +1,3 @@
[tool:pytest]
markers =
    slow: mark test as slow.
This file registers the custom marker with pytest. It reduces pytest warnings.
@@ -6,7 +6,7 @@
setup_requires: List[str] = []
install_requires: List[str] = [
    "ase>=3.18, <4.0.0",  # Note that we require ase==3.21.1 for pytest.
    "pymatgen",
    "pymatgen>=2020.1.28",
pymatgen's get_neighbor_list is relatively new, so I added a version constraint here.
Thank you! I added several comments, but the implementation itself LGTM.
@@ -20,7 +20,7 @@ main() {
docker run --runtime=nvidia --rm --volume="$(pwd)":/workspace -w /workspace \
    ${IMAGE} \
    bash -x -c "pip install flake8 pytest pytest-cov pytest-xdist pytest-benchmark && \
        pip install cupy-cuda102 pytorch-pfn-extras && \
        pip install cupy-cuda102 pytorch-pfn-extras!=0.5.0 && \
What kind of error did you get with 0.5.0?
I don't know the details yet, but you can see the error log in the CI log. I guess the combination of that specific version and Python 3.6 caused this problem.
atoms_dict = {"mol": mol, "slab": slab, "large": large_bulk}

return atoms_dict[request.param]
Just a comment (low priority): if only one of mol, slab, or large_bulk is used, it looks more efficient to use an if/else block (switch) instead of creating all atoms every time.
I agree with the comment, but this way it is easier to unify the implementations between test_torch_dftd3_calculator and test_torch_dftd3_calculator_batch (the latter uses multiple atoms). So if speed is not a problem, I think we can keep this style for now.
calc1.reset()
atoms.calc = calc1
f1 = atoms.get_forces()
e1 = atoms.get_potential_energy()
if np.all(atoms.pbc == np.array([True, True, True])):
    s1 = atoms.get_stress()
Why do we need to move this block?
The previous implementation calculated s1 using calc2, and thus the stress check was not done correctly.
"case4": [large_bulk],
}

return atoms_dict[request.param]
Just a comment (low priority): if only one of mol, slab, or large_bulk is used, it looks more efficient to use an if/else block (switch) instead of creating all atoms every time.
torch_dftd/functions/dftd3.py
Outdated
)
batch_triplets = None if batch_edge is None else batch_triplets

# Apply `cnthr` cutoff threshold for r_kj
idx_j, idx_k = triplet_node_index[:, 1], triplet_node_index[:, 2]
ts2 = triplet_shift[:, 2]
r_jk = calc_distances(pos, torch.stack([idx_j, idx_k], dim=0), cell, ts2, batch_triplets)
ts2 = None if shift_abc is None else shift_abc[edge_jk[:, 0]] - shift_abc[edge_jk[:, 1]]
I think you can now rename ts2 to something else. shift_jk?
torch_dftd/functions/dftd3.py
Outdated
@@ -306,10 +315,11 @@ def edisp(
    triplet_node_index[:, 1],
    triplet_node_index[:, 2],
)
ts0, ts1, ts2 = triplet_shift[:, 0], triplet_shift[:, 1], triplet_shift[:, 2]
ts0 = None if shift_abc is None else -shift_abc[edge_jk[:, 0]]
shift_ij?
torch_dftd/functions/dftd3.py
Outdated
@@ -306,10 +315,11 @@ def edisp(
    triplet_node_index[:, 1],
    triplet_node_index[:, 2],
)
ts0, ts1, ts2 = triplet_shift[:, 0], triplet_shift[:, 1], triplet_shift[:, 2]
ts0 = None if shift_abc is None else -shift_abc[edge_jk[:, 0]]
ts1 = None if shift_abc is None else -shift_abc[edge_jk[:, 1]]
shift_ik?
torch_dftd/functions/triplets.py
Outdated
shift = shift[is_larger][sort_inds]
edge_indices = torch.arange(shift.shape[0], dtype=torch.long, device=edge_index.device)
edge_indices = edge_indices[is_larger][sort_inds]
# shift = shift[is_larger][sort_inds]
It's okay to remove the comment, since shift is not used now.
You can also move the edge_indices initialization out of the if/else block.
torch_dftd/functions/triplets.py
Outdated
_offset + _j,
_offset + _k,
]
# torch.stack([-_shift[_j], -_shift[_k], _shift[_j] - _shift[_k]], dim=0)
remove this line
LGTM!
Refactor 230602
I found that the current implementation has several performance issues regarding the gradient with respect to cell. This PR modifies it. Since the changes are relatively large, I will add some comments.
Change summary:
- Use shift for the gradient instead of cell. shift is now in length units instead of cell units.
- This PR also contains a bugfix related to skewed cells.