Design proposal: unified NeighborGraph neighbor list (edges + optional angles) — one contract across backends and scenarios
#4
The proposal describes the graph as carrying all neighbors within …
For newly trained graph-native models, using all neighbors within …
A possible distinction might be: …
I think this distinction should be explicit in the contract.
The proposal says that in multi-rank MPI, …
I agree that multi-rank can be out of scope for the first implementation, but the contract should probably state the invariants clearly enough that a later MPI implementation will not require changing the graph fields.
For new graph-native implementations, the decomposition looks convincing. But for legacy parity, there are many details that are part of the effective descriptor contract:
Even if the graph representation is mathematically complete, preserving old model outputs may require descriptor-specific adapters or compatibility modes. It would be useful to distinguish “graph-native descriptor design” from “legacy descriptor parity mode”, and to specify which level of numerical equivalence is expected during migration.
Using static …
I think the contract is good, but without a concrete capacity-management story, the JAX/Paddle path may be hard to use reliably in production training.
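Here is one possible capacity-management loop, following the grow-and-recompile policy the proposal describes. Each batch is padded to a static capacity with a prefix mask. When a batch does not fit, the capacity grows geometrically and the compiled step is rebuilt.

This is a host-side sketch with hypothetical helpers (`pad_edges`, `run_with_capacity`). `step` is a placeholder for a jit-compiled function, and the doubling factor is an arbitrary choice.

```python
import numpy as np


def pad_edges(src, dst, edge_vec, capacity):
    """Pad a ragged edge list to a static capacity (hypothetical helper).

    Real entries form a prefix and padding a suffix, so the mask is the
    prefix threshold arange(capacity) < n_real. Padded slots use index 0
    and a zero vector, which keeps every index in range.
    Returns (src, dst, edge_vec, mask, overflow).
    """
    n_real = len(src)
    overflow = n_real > capacity
    n_keep = min(n_real, capacity)
    src_p = np.zeros(capacity, dtype=np.int64)
    dst_p = np.zeros(capacity, dtype=np.int64)
    vec_p = np.zeros((capacity, 3))
    src_p[:n_keep] = src[:n_keep]
    dst_p[:n_keep] = dst[:n_keep]
    vec_p[:n_keep] = edge_vec[:n_keep]
    mask = np.arange(capacity) < n_keep
    return src_p, dst_p, vec_p, mask, overflow


def run_with_capacity(batches, capacity, step, growth=2.0):
    """Call `step` on padded batches, growing the capacity on overflow.

    `step` stands in for a jit-compiled function: every new capacity is a
    new static shape, i.e. one recompilation.
    """
    recompiles = 0
    results = []
    for src, dst, edge_vec in batches:
        while True:
            padded = pad_edges(src, dst, edge_vec, capacity)
            if not padded[-1]:  # no overflow: the batch fits
                break
            capacity = int(np.ceil(capacity * growth))
            recompiles += 1
        results.append(step(*padded[:-1]))
    return results, capacity, recompiles
```

Geometric growth keeps the number of recompilations logarithmic in the largest batch. Pre-scanning the dataset for the maximum `Σ nedge` avoids recompilation entirely, at the cost of more padding.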
Thanks for writing this up. I like the overall direction: a flat …

A few contract-level details look important to clarify before this becomes the basis for PR splitting.
This is the right lower-interface direction, but I would state the ownership boundary very explicitly:
If …
For the attention descriptors, …
This needs careful wording. In multi-rank mode, force scatter inside the graph lower naturally lands on the extended local+halo node domain. To return owned forces, halo contributions must be reverse-accumulated to owners through the communication layer. Energy/virial reduction should be over owned centers only. I would avoid saying “scatters to local atoms only” without distinguishing single-rank local nodes from multi-rank extended nodes.
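The sketch below illustrates the fold described here in a single process. Forces are scattered on the extended local+halo node domain, and each halo row is then added onto its owner.

The `mapping` array (extended index → owning local index) stands in for the communication layer. In a real multi-rank run this step is a reverse halo exchange, not a local scatter. All helper names are hypothetical.

```python
import numpy as np


def segment_sum(values, index, n_segments):
    """Scatter-add rows of `values` into `n_segments` buckets given by `index`."""
    out = np.zeros((n_segments,) + values.shape[1:], dtype=values.dtype)
    np.add.at(out, index, values)
    return out


def edge_forces(g, src, dst, n_node):
    """F = segment_sum(g, dst) - segment_sum(g, src), with edge_vec = r_src - r_dst."""
    return segment_sum(g, dst, n_node) - segment_sum(g, src, n_node)


def fold_halo_forces(f_ext, mapping, n_local):
    """Add each extended row onto its owner; mapping is the identity on [0, n_local)."""
    return segment_sum(f_ext, mapping, n_local)


def owned_energy(edge_energy, dst, n_all, n_local):
    """Aggregate per-edge energy onto centers, then reduce over owned centers only."""
    return segment_sum(edge_energy, dst, n_all)[:n_local].sum()
```

Because the scatter is linear, folding the extended-domain forces onto owners gives the same result as scattering directly with `mapping[src]`. That is the ghost-free single-rank form.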
Global virial from edge gradients is fine, but atomic virial is convention-dependent. “Split between endpoints” may not match the existing DeepMD atomic virial semantics. I would make this an explicit compatibility decision rather than baking it into the new contract implicitly. Also, requiring …
The angle abstraction is good, especially because it keeps …
I agree with the long-term claim, but I would phrase this more as a migration checklist. Existing descriptors have legacy assumptions around type-wise …

Finally, since the stated goal includes C/C++ inference, the ABI details should be specified early: index dtype, mask dtype, memory layout/contiguity, whether padded indices must be in-range, ownership of buffers, and whether …

Overall: +1 on the architecture. I would just prefer to make these invariants explicit in the …

Authored by OpenClaw 2026.6.8 (844f405) (model: custom-chat-jinzhezeng-group/gpt-5.5)
Design updated: summary of changes from the review

Thanks to everyone's feedback above (@iProzd, @njzjz-bot). The contract has converged on several points; folding them in here so the thread reflects the current design. The top post is now stale on the items below, most importantly the atomic virial.

Substantive changes

1. Atomic virial → full-to-…
2. Additions
PR splitting remains out of scope for this proposal; the intent here is to converge on the contract.
Summary
This is a design proposal for a single, backend-agnostic neighbor-list contract, `NeighborGraph`, that replaces the divergent edge-schema work and serves every consumption scenario with one data structure and one set of operations.

Goals
- … `Dim("nedge")` (exact `Σ nedge_f`).
- … `E_max` capacity (jit needs static shapes).

Prior art: PR deepmodeling#5491 (pluggable `NeighborList` strategy + vesin), PR deepmodeling#5562 (edge schema for DPA4/SeZM; the reference this generalizes/cleans up), PR deepmodeling#5518 (edge-based force/virial).

First principles
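A brute-force builder makes the ghost-free, image-shift form of the edge list concrete. The sketch below uses NumPy; `build_periodic_edges` is a hypothetical helper, not an existing deepmd function.

It assumes that:

- coordinates are wrapped into the cell;
- lattice vectors are stored as the rows of `box`;
- `rcut` is smaller than every perpendicular cell width, so image shifts in {-1, 0, 1}³ are enough.

It emits `edge_vec = r_src + S·box − r_dst`, with `dst` as the center and `src` as the neighbor. Both indices stay in `[0, N)`.

```python
import itertools

import numpy as np


def build_periodic_edges(coord, box, rcut):
    """Ghost-free edge list for one periodic frame (hypothetical, brute force).

    coord: (n, 3) wrapped positions; box: (3, 3) with lattice vectors as rows.
    Returns src, dst, shift (n_edge, 3) and edge_vec = r_src + shift @ box - r_dst.
    """
    src, dst, shifts, vecs = [], [], [], []
    n = len(coord)
    for s in itertools.product((-1, 0, 1), repeat=3):
        shift = np.array(s)
        offset = shift @ box
        for i in range(n):  # center atom (dst)
            for j in range(n):  # neighbor atom (src)
                if i == j and not shift.any():
                    continue  # skip the atom itself in the home cell
                v = coord[j] + offset - coord[i]
                if v @ v < rcut * rcut:
                    src.append(j)
                    dst.append(i)
                    shifts.append(shift)
                    vecs.append(v)
    return (
        np.array(src, dtype=np.int64),
        np.array(dst, dtype=np.int64),
        np.array(shifts, dtype=np.int64).reshape(-1, 3),
        np.array(vecs, dtype=float).reshape(-1, 3),
    )
```

`coord` and `box` are consumed here and nowhere downstream: everything after the builder sees only `src`, `dst`, `edge_vec` and `atype`. The O(27·N²) loop is only meant for checking conventions. A production builder would use a cell list or a library such as vesin.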
An energy MLIP computes `E = Σ_i E_i`, where `E_i` depends only on the relative positions of atom `i`'s neighbors within `rcut`.

- The minimal, complete, ghost-free description is the edge list: `(src, dst)` with a displacement, where `dst` = center atom and `src` = neighbor. With PBC, `src` is a local atom and periodicity is an integer image shift `S`: `edge_vec = r_src + S·box − r_dst`. Both `src, dst ∈ [0, N)` are local, so there are no ghosts.
- … `E` only through `edge_vec` ⇒ `edge_vec` is the single autograd leaf.
- … `edge_vec = r_src − r_dst`): `g_e = ∂E/∂edge_vec_e`; `F_k = segment_sum(g_e, dst, N) − segment_sum(g_e, src, N)`, which scatters to local atoms only.
- … `W = −Σ_e g_e ⊗ edge_vec_e`; atomic = `g_e ⊗ edge_vec_e` split between endpoints. No absolute coordinates needed.

Consequence: the lower interface needs no coordinates and no box, only `atype` (per-node, for type embedding) plus the edge list. `coord`/`box` are inputs to the builder, not to the model's lower interface.

The contract: `NeighborGraph`

Relative to the current edge schema, the following are removed or relocated:

- `coord` removed: geometry is `edge_vec`; output shape / `num_segments` come from `n_node`.
- `nall` scalar removed: scatter domain = node domain = `N` (from `n_node`).
- `edge_scatter_index` removed: it was only needed for the hybrid "local-mapped message + extended scatter" choice. A clean design uses one index space: single-rank `(i, j, S)` local, multi-rank extended for both message and scatter. They coincide either way.
- `atype` relocated to a primary model input (gathered over `[0, N)`), not part of the nlist.

Derived for free: `frame_id = repeat(arange(nf), n_node)` (node→frame map, for per-frame `fparam` gather and per-frame energy reduction). `n_node` does double duty (sizes + `frame_id`).

Ghost completeness. Single-rank (including LAMMPS single-rank and Python PBC) is ghost-free. The builder converts host ghosts → local `(i, j, S)` (map ghost→owner, recover shift), so `N = nloc`, all nodes are local, `n_local=None`, and the core fields are complete.

Multi-rank MPI cannot eliminate ghosts (their owners live on other ranks). There the node space is extended to `N = nall` (local+halo), and `n_local` is required: energy reduces over owned nodes only, while message aggregation runs over all `N`. Multi-rank therefore needs `n_local` plus a `comm_dict` halo layer (as in today's DPA2/DPA3); it is out of scope for the core contract.

Angle extension (3-body)
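An angle is just a pair of edges that share a center, so the angle list can be enumerated from the flat edge list alone. The NumPy sketch below uses hypothetical helper names.

`build_angle_index` keeps only edges shorter than `a_rcut` and emits ordered pairs `(a, b)` with `a != b`. Whether pairs are ordered, and whether the diagonal is included, is descriptor-specific, so treat that choice as an assumption. `angle_cos` shows that an angle's geometry depends only on the two referenced `edge_vec` rows.

```python
import numpy as np


def build_angle_index(edge_vec, dst, a_rcut):
    """Ordered pairs (a, b), a != b, of edges that share a center.

    Only edges shorter than a_rcut take part. Returns angle_index of shape
    (2, A) pointing into the edge axis; the center of angle k is
    dst[angle_index[0, k]] == dst[angle_index[1, k]].
    """
    keep = np.nonzero(np.linalg.norm(edge_vec, axis=1) < a_rcut)[0]
    pairs = []
    for center in np.unique(dst[keep]):
        edges = keep[dst[keep] == center]
        pairs.extend((a, b) for a in edges for b in edges if a != b)
    return np.array(pairs, dtype=np.int64).reshape(-1, 2).T


def angle_cos(edge_vec, angle_index):
    """cos(theta) of each angle, derived from its two referenced edge vectors."""
    va = edge_vec[angle_index[0]]
    vb = edge_vec[angle_index[1]]
    dot = np.sum(va * vb, axis=1)
    return dot / (np.linalg.norm(va, axis=1) * np.linalg.norm(vb, axis=1))
```

`angle_cos` reads nothing but `edge_vec`. Differentiating the energy with respect to `edge_vec` therefore already accounts for the angle terms, without a separate angle leaf.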
DPA3's `repflows` has a genuine three-body representation `a^{ijk}`. It has its own `a_rcut < rcut`, `a_sel` and `update_angle`. It also supports either a `use_dynamic_sel` ragged mode or a static `nf*nloc*a_sel*a_sel` dense capacity. `se_t`/`se_t_tebd` have the same structure. An angle is a pair of edges sharing a center.

- `angle_index (2, A)` + `angle_mask (A,)`. `angle_index` points into edges `[0, E)`: it reuses `edge_vec`, and the center is derivable as `dst(edge_a) == dst(edge_b)`. The angle axis `A` is flat and ragged (torch dynamic / jax `A_max` + mask), which is tighter than the dense `a_sel²`.
- `edge_vec` stays the only geometry leaf. An angle's geometry (`cos θ_ijk`, …) is a derived function of its two referenced edge vectors, so `grad(E, edge_vec)` still captures the coordinate dependence of the angle terms. Force/virial are unchanged: the angle list adds topology only, with no new leaf and no new backward pass.
- … `edge_a`; angle→node by the shared center).
- … `angle_index=None`, and edge-only descriptors ignore it.

This is the k-body hierarchy: 2-body = edges (node pairs), 3-body = angles (edge pairs sharing a node), all sharing one `edge_vec` geometry leaf.

Length policy: `GraphLayout` (the only torch/jax difference)

- `[ real entries, frame-major contiguous | padding suffix ]`; masks are a prefix threshold `arange < real`.
- … `None`, exact counts + guard; export `Dim("nedge", min=min_edges)`.
- … `edge_capacity = E_max` (+ `angle_capacity`); node/frame `None` (derived).
- jax ragged: all set.
- … `Σnedge > E_max`, etc.) → the caller grows the capacity and recompiles outside jit (as jax-md's `did_buffer_overflow`), or the data loader pre-scans / bin-packs.

The data structure, the consuming code, and the math are identical across all scenarios. Only the edge-axis length policy varies, and `edge_mask` absorbs it. This matches the convergent approach across jax-md / jraph / mace-jax / nequip-jax (flat sender/receiver list + capacity/`E_max` + `segment_sum`).

Two primitives (the real new dpmodel work)
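Both primitives are small enough to prototype in NumPy, which helps pin down signs and masking before writing the backend versions. In the sketch below:

- `edge_segment_sum` uses `np.add.at` as the scatter and applies the mask before scattering.
- `per_frame_energy` derives `frame_id` from `n_node` as described in the contract.
- A toy harmonic edge energy `E = Σ_e ½k|edge_vec_e|²` (so `g_e = k·edge_vec_e`) stands in for a model's backward pass, so this `edge_energy_deriv` takes `g` rather than `E`.

The function names and the extra `mask`/`n_node` arguments are assumptions, not the final dpmodel signatures. The sketch returns per-edge virial terms rather than per-atom ones, because how to assign them to atoms is the convention question raised in the review.

```python
import numpy as np


def edge_segment_sum(values, index, n_segments, mask=None):
    """Masked scatter-add over the leading axis (NumPy stand-in for the primitive)."""
    if mask is not None:
        values = values * mask.reshape((-1,) + (1,) * (values.ndim - 1))
    out = np.zeros((n_segments,) + values.shape[1:], dtype=values.dtype)
    np.add.at(out, index, values)
    return out


def per_frame_energy(node_energy, n_node):
    """Reduce per-node energies per frame using frame_id = repeat(arange(nf), n_node)."""
    frame_id = np.repeat(np.arange(len(n_node)), n_node)
    return edge_segment_sum(node_energy, frame_id, len(n_node))


def harmonic_energy_and_grad(edge_vec, mask, k=1.0):
    """Toy model: E = sum_e 0.5 k |edge_vec_e|^2, hence g_e = k edge_vec_e (masked)."""
    m = mask[:, None]
    return 0.5 * k * np.sum(m * edge_vec**2), k * edge_vec * m


def edge_energy_deriv(g, edge_vec, src, dst, mask, n_node):
    """Force, per-edge virial and global virial from one g_e = dE/d edge_vec_e."""
    force = edge_segment_sum(g, dst, n_node, mask) - edge_segment_sum(g, src, n_node, mask)
    w_edge = -np.einsum("ei,ej->eij", g, edge_vec) * mask[:, None, None]
    return force, w_edge, w_edge.sum(axis=0)
```

For a non-periodic system, the global virial also equals `Σ_i F_i ⊗ r_i`. That gives a cheap consistency check alongside finite differences.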
- `edge_segment_sum(values, index, n_segments)`: a backend-dispatched scatter (`jax.ops.segment_sum`, torch `index_add_`/`scatter_add_`, numpy `np.add.at`, paddle `segment_sum`). Padding is handled by `values * edge_mask` before the scatter. It is used for dst-aggregation, angle aggregation, per-frame energy reduction, and the force/virial scatter. This is the one non-array-API primitive the whole jax MLIP ecosystem pays for.
- `edge_energy_deriv(E, edge_vec, src, dst, edge_mask, ...)`: model-level force/virial. One `grad(E, edge_vec) = g_e` suffices: force, atomic virial, and global virial all fall out of that single backward pass. (`atom_virial` is therefore always computed. It is the precursor to the global virial, not an add-on, so there is no `do_atomic_virial` gate in the edge path.)

Lower interface: a new method, legacy `forward_common_lower` untouched

The edge-based lower gets a distinct name, `forward_common_lower_graph`. The legacy quartet `forward_common_lower(extended_coord, extended_atype, nlist, mapping)` is left unchanged for non-edge descriptors. Dispatch is by model type, and `.pt2` metadata records `lower_input_kind ∈ {nlist, graph}`. A `neighbor_graph_from_extended` adapter bridges a host-supplied quartet into the graph lower.

Node tensors flatten from `(nf, nloc, …)` to `(N, …)` + `n_node`; edge tensors are already flat; per-frame tensors stay `(nf, …)`. The public `forward(coord, atype, box)` is unchanged: it builds the graph internally and ravels/unravels at the I/O boundary.

Descriptor portability: all families port
Every deepmd descriptor is translation-invariant and depends on the environment only through `r_ij = edge_vec`, which the edge list carries completely. Anything per-center is recoverable by grouping edges by `dst`. There are three computation-pattern classes:

- … `se_e2_a`, `se_r`, `dpa2`, `dpa3` (edge channel), DPA4/SeZM. Each edge contributes independently; aggregate with `segment_sum`. Even `se_e2_a`'s bilinear `D = (GᵀR)(RᵀG)` factorizes (`T_i = segment_sum(R_ij ⊗ G_ij)`, `D_i = T_iᵀT_i`). Trivial.
- … `dpa1`, `se_atten_v2`. Per-center `segment_softmax(logits, dst)` (as in jraph attention).
- … `se_t`, `se_t_tebd`, the dpa3 angle channel. These are served first-class by the optional angle list (built once, consumed via `edge_segment_sum`), not by ad-hoc triplet enumeration inside each descriptor.

The port can be incremental, since the legacy quartet lower coexists. Once all families are on the graph form, the quartet lower can be retired.
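The first two pattern classes can be checked against a dense per-center reference. The sketch below is NumPy with hypothetical names.

- `se_e2_a_core` keeps only the bilinear core `D_i = T_iᵀT_i`. It omits the axis-neuron column truncation and any normalization that the real `se_e2_a` applies.
- `segment_softmax` subtracts a per-center maximum for numerical stability, then normalizes within each center.

```python
import numpy as np


def segment_sum(values, index, n_segments):
    """Scatter-add rows of `values` into `n_segments` buckets given by `index`."""
    out = np.zeros((n_segments,) + values.shape[1:], dtype=values.dtype)
    np.add.at(out, index, values)
    return out


def se_e2_a_core(R, G, dst, n_node):
    """Bilinear core D_i = T_i^T T_i with T_i = sum over edges of R_e (outer) G_e.

    R: (E, 4) environment-matrix rows, G: (E, M) per-edge embeddings,
    dst: (E,) center index of each edge. Returns (n_node, M, M).
    """
    T = segment_sum(np.einsum("ea,em->eam", R, G), dst, n_node)  # (N, 4, M)
    return np.einsum("nam,nak->nmk", T, T)


def segment_softmax(logits, index, n_segments):
    """Softmax of per-edge logits over the edges sharing each center."""
    seg_max = np.full(n_segments, -np.inf)
    np.maximum.at(seg_max, index, logits)
    ex = np.exp(logits - seg_max[index])  # shift by the per-center max
    return ex / segment_sum(ex, index, n_segments)[index]
```

Neither function needs a padded neighbor axis. Per-center structure comes entirely from `dst`, which is what lets these families share the flat edge layout.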
Decisions baked in
- `edge_vec` sign = `r_src − r_dst` (neighbor − center; matches the existing env_mat convention).
- `edge_scatter_index` removed (hybrid-only artifact).
- `nall` scalar removed (scatter domain = `N`).
- `coord`/`box` removed from the lower interface.
- `atom_virial` always computed, no `do_atomic_virial` gate.
- … `NeighborGraph` + `edge_segment_sum` (commit to a `segment_sum` shim in dpmodel).
- … `n_node = [nloc]*nf` special case.

PR splitting and the concrete implementation plan are intentionally out of scope for this proposal; the goal here is to agree on the contract and the architecture. Feedback welcome.