Skip to content

Commit

Permalink
feat: orbital moments (#97)
Browse files Browse the repository at this point in the history
* Add new orbital moments field to Magnetism class
* Add read for spin and orbital moments
* Add selection to magnetism.moments
* Allow selection for total_moments
* Add plot with selection
* Fix viewer with nonstandard cell

* Refactor several tests using parametrized fixtures
* Require VASP 6.5 for orbital moments
  • Loading branch information
martin-schlipf committed Jun 28, 2023
1 parent 07a2aa0 commit fb935f2
Show file tree
Hide file tree
Showing 7 changed files with 333 additions and 240 deletions.
163 changes: 117 additions & 46 deletions src/py4vasp/_data/magnetism.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,22 @@

from py4vasp import exception
from py4vasp._data import base, slice_, structure
from py4vasp._util import documentation, reader
from py4vasp._util import documentation

_index_note = """Notes
_index_note = """\
Notes
-----
The index order is different compared to the raw data when noncollinear calculations
are used. This routine returns the magnetic moments as (steps, atoms, orbitals,
directions)."""

_moment_selection = """\
selection : str
If VASP was run with LORBMOM = T, the orbital moments are computed and the routine
will default to the total moments. You can specify "spin" or "orbital" to select
the individual contributions instead.
"""


@documentation.format(examples=slice_.examples("magnetism"))
class Magnetism(slice_.Mixin, base.Refinery, structure.Mixin):
Expand Down Expand Up @@ -62,13 +70,20 @@ def to_dict(self):
return {
"charges": self.charges(),
"moments": self.moments(),
**self._add_spin_and_orbital_moments(),
}

@base.data_access
@documentation.format(examples=slice_.examples("magnetism", "to_graph"))
def plot(self, supercell=None):
@documentation.format(
selection=_moment_selection, examples=slice_.examples("magnetism", "to_graph")
)
def plot(self, selection="total", supercell=None):
"""Visualize the magnetic moments as arrows inside the structure.
Paramaters
----------
{selection}
Returns
-------
Viewer3d
Expand All @@ -80,12 +95,11 @@ def plot(self, supercell=None):
{examples}
"""
if self._is_slice:
message = (
raise exception.NotImplemented(
"Visualizing magnetic moments for more than one step is not implemented"
)
raise exception.NotImplemented(message)
viewer = self._structure[self._steps].plot(supercell)
moments = self._prepare_magnetic_moments_for_plotting()
moments = self._prepare_magnetic_moments_for_plotting(selection)
if moments is not None:
viewer.show_arrows_at_atoms(moments)
return viewer
Expand All @@ -102,16 +116,22 @@ def charges(self):
{examples}
"""
moments = _Moments(self._raw_data.moments)
return moments[self._steps, 0, :, :]
self._raise_error_if_steps_out_of_bounds()
return self._raw_data.spin_moments[self._steps, 0]

@base.data_access
@documentation.format(
index_note=_index_note, examples=slice_.examples("magnetism", "moments")
selection=_moment_selection,
index_note=_index_note,
examples=slice_.examples("magnetism", "moments"),
)
def moments(self):
def moments(self, selection="total"):
"""Read the magnetic moments of the selected steps.
Parameters
----------
{selection}
Returns
-------
np.ndarray
Expand All @@ -122,16 +142,14 @@ def moments(self):
{examples}
"""
moments = _Moments(self._raw_data.moments)
_fail_if_steps_out_of_bounds(moments, self._steps)
if moments.shape[1] == 1:
self._raise_error_if_steps_out_of_bounds()
self._raise_error_if_selection_not_available(selection)
if self._only_charge:
return None
elif moments.shape[1] == 2:
return moments[self._steps, 1, :, :]
elif self._spin_polarized:
return self._collinear_moments()
else:
moments = moments[self._steps, 1:, :, :]
direction_axis = 1 if moments.ndim == 4 else 0
return np.moveaxis(moments, direction_axis, -1)
return self._noncollinear_moments(selection)

@base.data_access
@documentation.format(examples=slice_.examples("magnetism", "total_charges"))
Expand All @@ -150,11 +168,17 @@ def total_charges(self):

@base.data_access
@documentation.format(
index_note=_index_note, examples=slice_.examples("magnetism", "total_moments")
selection=_moment_selection,
index_note=_index_note,
examples=slice_.examples("magnetism", "total_moments"),
)
def total_moments(self):
def total_moments(self, selection="total"):
"""Read the total magnetic moments of the selected steps.
Parameters
----------
{selection}
Returns
-------
np.ndarray
Expand All @@ -165,19 +189,55 @@ def total_moments(self):
{examples}
"""
moments = _Moments(self._raw_data.moments)
_fail_if_steps_out_of_bounds(moments, self._steps)
if moments.shape[1] == 1:
return None
elif moments.shape[1] == 2:
return _sum_over_orbitals(self.moments())
return _sum_over_orbitals(self.moments(selection), is_vector=self._noncollinear)

@property
def _only_charge(self):
return self._raw_data.spin_moments.shape[1] == 1

@property
def _spin_polarized(self):
return self._raw_data.spin_moments.shape[1] == 2

@property
def _noncollinear(self):
return self._raw_data.spin_moments.shape[1] == 4

@property
def _has_orbital_moments(self):
return not self._raw_data.orbital_moments.is_none()

def _collinear_moments(self):
return self._raw_data.spin_moments[self._steps, 1]

def _noncollinear_moments(self, selection):
spin_moments = self._raw_data.spin_moments[self._steps, 1:]
if self._has_orbital_moments:
orbital_moments = self._raw_data.orbital_moments[self._steps, 1:]
else:
total_moments = _sum_over_orbitals(moments[self._steps, 1:, :, :])
direction_axis = 1 if total_moments.ndim == 3 else 0
return np.moveaxis(total_moments, direction_axis, -1)
orbital_moments = np.zeros_like(spin_moments)
if selection == "orbital":
moments = orbital_moments
elif selection == "spin":
moments = spin_moments
else:
moments = spin_moments + orbital_moments
direction_axis = 1 if moments.ndim == 4 else 0
return np.moveaxis(moments, direction_axis, -1)

def _add_spin_and_orbital_moments(self):
if not self._has_orbital_moments:
return {}
spin_moments = self._raw_data.spin_moments[self._steps, 1:]
orbital_moments = self._raw_data.orbital_moments[self._steps, 1:]
direction_axis = 1 if spin_moments.ndim == 4 else 0
return {
"spin_moments": np.moveaxis(spin_moments, direction_axis, -1),
"orbital_moments": np.moveaxis(orbital_moments, direction_axis, -1),
}

def _prepare_magnetic_moments_for_plotting(self):
moments = self.total_moments()
def _prepare_magnetic_moments_for_plotting(self, selection):
moments = self.total_moments(selection)
moments = _convert_moment_to_3d_vector(moments)
max_length_moments = _max_length_moments(moments)
if max_length_moments > 1e-15:
Expand All @@ -186,23 +246,34 @@ def _prepare_magnetic_moments_for_plotting(self):
else:
return None


class _Moments(reader.Reader):
def error_message(self, key, err):
key = np.array(key)
steps = key if key.ndim == 0 else key[0]
return (
f"Error reading the magnetic moments. Please check if the steps "
f"`{steps}` are properly formatted and within the boundaries. "
"Additionally, you may consider the original error message:\n" + err.args[0]
def _raise_error_if_steps_out_of_bounds(self):
try:
np.zeros(self._raw_data.spin_moments.shape[0])[self._steps]
except IndexError as error:
raise exception.IncorrectUsage(
f"Error reading the magnetic moments. Please check if the steps "
f"`{self._steps}` are properly formatted and within the boundaries."
) from error

def _raise_error_if_selection_not_available(self, selection):
if selection not in ("spin", "orbital", "total"):
raise exception.IncorrectUsage(
f"The selection {selection} is incorrect. Please check if it is spelled "
"correctly. Possible choices are total, spin, or orbital."
)
if selection != "orbital" or self._has_orbital_moments:
return
raise exception.NoData(
"There are no orbital moments in the VASP output. Please make sure that you "
"run the calculation with LORBMOM = T and LSORBIT = T."
)


def _fail_if_steps_out_of_bounds(moments, steps):
moments[steps] # try to access requested step raising an error if out of bounds


def _sum_over_orbitals(quantity):
def _sum_over_orbitals(quantity, is_vector=False):
if quantity is None:
return None
if is_vector:
return np.sum(quantity, axis=-2)
return np.sum(quantity, axis=-1)


Expand Down
19 changes: 13 additions & 6 deletions src/py4vasp/_data/viewer3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ def to_serializable(self):
return list(self.tail), list(self.tip), list(self.color), float(self.radius)


def _rotate(arrow, transformation):
return _Arrow3d(
transformation @ arrow.tail, transformation @ arrow.tip, arrow.color
)


_x_axis = _Arrow3d(tail=np.zeros(3), tip=np.array((3, 0, 0)), color=[1, 0, 0])
_y_axis = _Arrow3d(tail=np.zeros(3), tip=np.array((0, 3, 0)), color=[0, 1, 0])
_z_axis = _Arrow3d(tail=np.zeros(3), tip=np.array((0, 0, 3)), color=[0, 0, 1])
Expand Down Expand Up @@ -62,13 +68,14 @@ def from_structure(cls, structure, supercell=None):
"""
ase = structure.to_ase(supercell)
# ngl works with the standard form, so we need to store the positions in the same format
standard_cell, _ = ase.cell.standard_form()
ase.set_cell(standard_cell, scale_atoms=True)
standard_cell, transformation = ase.cell.standard_form()
ase.set_cell(standard_cell)
res = cls(nglview.show_ase(ase))
res._lengths = tuple(ase.cell.lengths())
res._angles = tuple(ase.cell.angles())
res._positions = ase.positions
res._number_cells = res._calculate_number_cells(supercell)
res._transformation = transformation
return res

def _calculate_number_cells(self, supercell):
Expand Down Expand Up @@ -110,9 +117,9 @@ def show_axes(self):
if self._axes is not None:
return
self._axes = (
self._make_arrow(_x_axis),
self._make_arrow(_y_axis),
self._make_arrow(_z_axis),
self._make_arrow(_rotate(_x_axis, self._transformation)),
self._make_arrow(_rotate(_y_axis, self._transformation)),
self._make_arrow(_rotate(_z_axis, self._transformation)),
)

def hide_axes(self):
Expand Down Expand Up @@ -141,7 +148,7 @@ def show_arrows_at_atoms(self, arrows, color=[0.1, 0.1, 0.8]):
"""
if self._positions is None:
raise exception.RefinementError("Positions of atoms are not known.")
arrows = np.repeat(arrows, self._number_cells, axis=0)
arrows = np.repeat(arrows @ self._transformation.T, self._number_cells, axis=0)
for tail, arrow in zip(self._positions, arrows):
tip = tail + arrow
arrow = _Arrow3d(tail, tip, color)
Expand Down
4 changes: 3 additions & 1 deletion src/py4vasp/_raw/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,10 @@ class Magnetism:

structure: Structure
"Structural information about the system."
moments: VaspData
spin_moments: VaspData
"Contains the charge and magnetic moments atom and orbital resolved."
orbital_moments: VaspData = NONE()
"Contains the orbital magnetization for all atoms"


@dataclasses.dataclass
Expand Down
12 changes: 7 additions & 5 deletions src/py4vasp/_raw/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@
internal_strain="results/linear_response/internal_strain",
)
#
input = "input/kpoints"
input_ = "input/kpoints"
result = "results/electron_eigenvalues"
schema.add(
raw.Kpoint,
Expand All @@ -277,7 +277,7 @@
label_indices=f"{input}/positions_labels_kpoints",
cell=Link("cell", DEFAULT_SOURCE),
)
input = "input/kpoints_opt"
input_ = "input/kpoints_opt"
result = "results/electron_eigenvalues_kpoints_opt"
schema.add(
raw.Kpoint,
Expand All @@ -290,7 +290,7 @@
label_indices=f"{input}/positions_labels_kpoints",
cell=Link("cell", DEFAULT_SOURCE),
)
input = "input/kpoints_wan"
input_ = "input/kpoints_wan"
result = "results/electron_eigenvalues_kpoints_wan"
schema.add(
raw.Kpoint,
Expand All @@ -303,7 +303,7 @@
label_indices=f"{input}/positions_labels_kpoints",
cell=Link("cell", DEFAULT_SOURCE),
)
input = "input/qpoints"
input_ = "input/qpoints"
result = "results/phonons"
schema.add(
raw.Kpoint,
Expand All @@ -320,8 +320,10 @@
#
schema.add(
raw.Magnetism,
required=raw.Version(6, 5),
structure=Link("structure", DEFAULT_SOURCE),
moments="intermediate/ion_dynamics/magnetism/moments",
spin_moments="intermediate/ion_dynamics/magnetism/spin_moments/values",
orbital_moments="intermediate/ion_dynamics/magnetism/orbital_moments/values",
)
#
group = "intermediate/pair_correlation"
Expand Down
15 changes: 10 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def kpoint(selection):

@staticmethod
def magnetism(selection):
return _magnetism(_number_components(selection))
return _magnetism(selection)

@staticmethod
def pair_correlation(selection):
Expand Down Expand Up @@ -276,7 +276,7 @@ def raw_data():
def _number_components(selection):
if selection == "collinear":
return 2
elif selection == "noncollinear":
elif selection in ("noncollinear", "orbital_moments"):
return 4
elif selection == "charge_only":
return 1
Expand Down Expand Up @@ -448,12 +448,17 @@ def _grid_kpoints(mode, labels):
return kpoints


def _magnetism(number_components):
def _magnetism(selection):
lmax = 3
number_components = _number_components(selection)
shape = (number_steps, number_components, number_atoms, lmax)
return raw.Magnetism(
structure=_Fe3O4_structure(), moments=np.arange(np.prod(shape)).reshape(shape)
magnetism = raw.Magnetism(
structure=_Fe3O4_structure(),
spin_moments=_make_data(np.arange(np.prod(shape)).reshape(shape)),
)
if selection == "orbital_moments":
magnetism.orbital_moments = _make_data(np.sqrt(magnetism.spin_moments))
return magnetism


def _single_band(projectors):
Expand Down
Loading

0 comments on commit fb935f2

Please sign in to comment.