Skip to content

Commit

Permalink
Raise TypeError for invalid cell
Browse files Browse the repository at this point in the history
  • Loading branch information
lan496 committed Feb 3, 2024
1 parent 63a7a4b commit 790b63c
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 43 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/step_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,4 @@ jobs:
uses: lukka/run-cmake@v10.3
with:
workflowPreset: "${{ matrix.toolchain }}-ci"
continue-on-error: ${{ matrix.experimental && inputs.mask-experimental}}
continue-on-error: "${{ matrix.experimental && inputs.mask-experimental}}"
4 changes: 4 additions & 0 deletions ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ GitHub release pages and in the git history.

## \[Unreleased\]

### Python interface

- Raise `TypeError` when a given `cell` is invalid.

## v2.3.0 (27 Jan. 2024)

### Features
Expand Down
6 changes: 4 additions & 2 deletions doc/python-interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,10 @@ positions = [[0.0, 0.0, 0.0], # Al
numbers = [1, 2, 2, 2] # Al, Ni, Ni, Ni
```

Version 1.9.5 or later: The methods that use the crystal structure
will return `None` when a crystal structure is not properly given.
```{note}
When a crystal structure is not properly given, `TypeError` will be raised.
For previous versions between 1.9.5 and 2.3.0, `None` will be returned for the invalid input.
```

### Symmetry tolerance (`symprec`, `angle_tolerance`, `mag_symprec`)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ license-files = { paths = ["COPYING"] }
readme = "python/README.rst"
dynamic = ["version"]
dependencies = [
"numpy"
"numpy",
]
authors = [
{name="Atsushi Togo", email="atz.togo@gmail.com"},
Expand Down
72 changes: 33 additions & 39 deletions python/spglib/spglib.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,7 @@ def get_symmetry(
"""
_set_no_error()

lattice, _, _, magmoms = _expand_cell(cell)
if lattice is None:
return None
_, _, _, magmoms = _expand_cell(cell)

if magmoms is None:
# Get symmetry operations without on-site tensors (i.e. normal crystal)
Expand Down Expand Up @@ -380,11 +378,8 @@ def get_magnetic_symmetry(
_set_no_error()

lattice, positions, numbers, magmoms = _expand_cell(cell)
if lattice is None:
return None
if magmoms is None:
warnings.warn("Specify magnetic moments in cell.")
return None
raise TypeError("Specify magnetic moments in cell.")

max_size = len(positions) * 96
rotations = np.zeros((max_size, 3, 3), dtype="intc", order="C")
Expand Down Expand Up @@ -694,8 +689,6 @@ def get_symmetry_dataset(
_set_no_error()

lattice, positions, numbers, _ = _expand_cell(cell)
if lattice is None:
return None

spg_ds = spg.dataset(
lattice,
Expand All @@ -719,8 +712,6 @@ def get_symmetry_layerdataset(cell, aperiodic_dir=2, symprec=1e-5):
_set_no_error()

lattice, positions, numbers, _ = _expand_cell(cell)
if lattice is None:
return None

spg_ds = spg.layerdataset(
lattice,
Expand Down Expand Up @@ -830,8 +821,6 @@ def get_magnetic_symmetry_dataset(
_set_no_error()

lattice, positions, numbers, magmoms = _expand_cell(cell)
if lattice is None:
return None

tensor_rank = magmoms.ndim - 1

Expand Down Expand Up @@ -1303,8 +1292,6 @@ def standardize_cell(
_set_no_error()

lattice, _positions, _numbers, _ = _expand_cell(cell)
if lattice is None:
return None

# Atomic positions have to be specified by scaled positions for spglib.
num_atom = len(_positions)
Expand Down Expand Up @@ -1350,8 +1337,6 @@ def refine_cell(cell, symprec=1e-5, angle_tolerance=-1.0):
_set_no_error()

lattice, _positions, _numbers, _ = _expand_cell(cell)
if lattice is None:
return None

# Atomic positions have to be specified by scaled positions for spglib.
num_atom = len(_positions)
Expand Down Expand Up @@ -1394,8 +1379,6 @@ def find_primitive(cell, symprec=1e-5, angle_tolerance=-1.0):
_set_no_error()

lattice, positions, numbers, _ = _expand_cell(cell)
if lattice is None:
return None

num_atom_prim = spg.primitive(lattice, positions, numbers, symprec, angle_tolerance)
_set_error_message()
Expand Down Expand Up @@ -1981,36 +1964,47 @@ def get_error_message():
return spglib_error.message


def _expand_cell(cell):
def _expand_cell(
cell,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray | None]:
lattice = np.array(np.transpose(cell[0]), dtype="double", order="C")
positions = np.array(cell[1], dtype="double", order="C")
numbers = np.array(cell[2], dtype="intc")
if len(cell) > 3:
if len(cell) == 4:
magmoms = np.array(cell[3], order="C", dtype="double")
else:
elif len(cell) == 3:
magmoms = None

if _check(lattice, positions, numbers, magmoms):
return (lattice, positions, numbers, magmoms)
else:
return (None, None, None, None)

raise TypeError("cell has to be a tuple of 3 or 4 elements.")

def _check(lattice, positions, numbers, magmoms):
# Sanity check
if lattice.shape != (3, 3):
return False
if positions.ndim != 2:
return False
if positions.shape[1] != 3:
return False
raise TypeError("lattice has to be a (3, 3) array.")
if not (positions.ndim == 2 and positions.shape[1] == 3):
raise TypeError("positions has to be a (num_atoms, 3) array.")
num_atoms = positions.shape[0]
if numbers.ndim != 1:
return False
if len(numbers) != positions.shape[0]:
return False
raise TypeError("numbers has to be a (num_atoms,) array.")
if len(numbers) != num_atoms:
raise TypeError("numbers has to have the same number of atoms as positions.")
if magmoms is not None:
if len(magmoms) != len(numbers):
return False
return True
if len(magmoms) != num_atoms:
raise TypeError(
"magmoms has to have the same number of atoms as positions."
)
if magmoms.ndim == 1:
# collinear
pass
elif magmoms.ndim == 2:
# non-collinear
if magmoms.shape[1] != 3:
raise TypeError(
"non-collinear magmoms has to be a (num_atoms, 3) array."
)
else:
raise TypeError("magmoms has to be a 1D or 2D array.")

return (lattice, positions, numbers, magmoms)


def _set_error_message():
Expand Down
8 changes: 8 additions & 0 deletions test/functional/python/test_cell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import pytest
from spglib.spglib import _expand_cell


def test_expand_cell():
with pytest.raises(TypeError) as e:
_expand_cell(None)
assert e.value

0 comments on commit 790b63c

Please sign in to comment.