Skip to content

Commit

Permalink
Merge pull request #111 from hakonanes/improve-phase-adding-to-phase-…
Browse files Browse the repository at this point in the history
…list

Replace PhaseList.__setitem__() with PhaseList.add() method
  • Loading branch information
dnjohnstone committed Aug 20, 2020
2 parents 20bb1f2 + 26a2d46 commit d64800e
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 63 deletions.
77 changes: 53 additions & 24 deletions orix/crystal_map/phase_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,30 +586,6 @@ def __getitem__(self, key):
else:
return PhaseList(d)

def __setitem__(self, key, value):
"""Add a phase to the list with a name, point group and
structure.
"""
if key not in self.names:
# Make sure the new phase gets a new color
color_new = None
for color_name in ALL_COLORS.keys():
if color_name not in self.colors:
color_new = color_name
break

# Create new ID
if self.ids:
new_phase_id = max(self.ids) + 1
else: # `self.phase_ids` is an empty list
new_phase_id = 0

self._dict[new_phase_id] = Phase(
name=key, point_group=value, color=color_new
)
else:
raise ValueError(f"{key} is already in the phase list {self.names}.")

def __delitem__(self, key):
"""Delete a phase from the phase list.
Expand Down Expand Up @@ -716,3 +692,56 @@ def id_from_name(self, name):
if name == phase.name:
return phase_id
raise KeyError(f"'{name}' is not among the phase names {self.names}.")

def add(self, value):
"""Add phases to the end of a phase list, incrementing the phase
IDs.
Parameters
----------
value : Phase, list of Phase or another PhaseList
Phase(s) to add. If a PhaseList is added, the phase IDs in the
old list are lost.
Examples
--------
>>> from orix.crystal_map import Phase, PhaseList
>>> pl = PhaseList(names=["a", "b"], space_groups=[10, 20])
>>> pl.add(Phase("c", space_group=30))
>>> pl.add([Phase("d", space_group=40), Phase("e")])
>>> pl.add(PhaseList(names=["f", "g"], space_groups=[60, 70]))
>>> pl
Id Name Space group Point group Proper point group Color
0 a P2/m 2/m 112 tab:blue
1 b C2221 222 222 tab:orange
2 c Pnc2 mm2 211 tab:green
3 d Ama2 mm2 211 tab:red
4 e None None None tab:purple
5 f Pbcn mmm 222 tab:brown
6 g Fddd mmm 222 tab:pink
"""
if isinstance(value, Phase):
value = [value]
if isinstance(value, PhaseList):
value = [i for _, i in value]
for phase in value:
if phase.name in self.names:
raise ValueError(
f"'{phase.name}' is already in the phase list {self.names}"
)

# Ensure a new color
if phase.color in self.colors:
for color_name in ALL_COLORS.keys():
if color_name not in self.colors:
phase.color = color_name
break

# Increment the highest phase ID
if self.ids:
new_phase_id = max(self.ids) + 1
else: # `self.phase_ids` is an empty list
new_phase_id = 0

# Finally, add the phase to the list
self._dict[new_phase_id] = phase
8 changes: 4 additions & 4 deletions orix/tests/test_crystal_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,14 +346,14 @@ def test_set_phase_id_with_unknown_id(self, crystal_map, set_phase_id, index_err
with pytest.raises(IndexError, match="list index out of range"):
# `set_phase_id` ID is not in `self.phases.phase_ids`
cm[condition].phase_id = set_phase_id
_ = cm.__repr__()
_ = repr(cm)

# Add unknown ID to phase list to fix `self.__repr__()`
cm.phases["a"] = 432 # Add phase with ID 1
# Add unknown ID to phase list to fix `repr(self)`
cm.phases.add(Phase("a", point_group=432)) # Add phase with ID 1
else:
cm[condition].phase_id = set_phase_id

_ = cm.__repr__()
_ = repr(cm)

new_phase_ids = phase_ids + [set_phase_id]
new_phase_ids.sort()
Expand Down
95 changes: 60 additions & 35 deletions orix/tests/test_phase_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# along with orix. If not, see <http://www.gnu.org/licenses/>.

from diffpy.structure import Lattice, Structure
from diffpy.structure.spacegroups import GetSpaceGroup
import numpy as np
import pytest

Expand Down Expand Up @@ -235,9 +236,9 @@ class TestPhaseList:
@pytest.mark.parametrize("empty_input", [(), [], {}])
def test_init_empty_phaselist(self, empty_input):
pl = PhaseList(empty_input)
assert pl.__repr__() == "No phases."
pl["al"] = "m-3m"
assert pl.__repr__() == (
assert repr(pl) == "No phases."
pl.add(Phase("al", point_group="m-3m"))
assert repr(pl) == (
"Id Name Space group Point group Proper point group Color\n"
" 0 al None m-3m 432 tab:blue"
)
Expand Down Expand Up @@ -492,39 +493,63 @@ def test_get_from_phaselist_not_indexed(
phase_list.add_not_indexed()
assert phase_list[:3].ids == desired_ids

@pytest.mark.parametrize(
"key, value, already_there",
[("d", "m-3m", False), ("d", 432, False), ("c", 432, True),],
)
def test_set_phase_in_phaselist(self, phase_list, key, value, already_there):
if already_there:
with pytest.raises(ValueError, match=f"{key} is already in the phase "):
phase_list[key] = value
else:
desired_names = phase_list.names + [key]
desired_point_group_names = [s.name for s in phase_list.point_groups] + [
str(value)
]

phase_list[key] = value

assert phase_list.names == desired_names
assert [
s.name for s in phase_list.point_groups
] == desired_point_group_names

def test_set_phase_in_empty_phaselist(self):
def test_add_phase_in_empty_phaselist(self):
"""Add Phase to empty PhaseList."""
sg_no = 10
name = "a"
pl = PhaseList()
pl.add(Phase(name, space_group=sg_no))
assert pl.ids == [0]
assert pl.names == [name]
assert pl.space_groups == [GetSpaceGroup(sg_no)]
assert pl.structures == [Structure()]

def test_add_list_phases_to_phaselist(self):
"""Add a list of Phase objects to PhaseList, also ensuring that
unique colors are given.
"""
names = ["a", "b"]
sg_no = [10, 20]
colors = ["tab:blue", "tab:orange"]
pl = PhaseList(names=names, space_groups=sg_no)
assert pl.colors == colors

new_names = ["c", "d"]
new_sg_no = [30, 40]
pl.add([Phase(name=n, space_group=i) for n, i in zip(new_names, new_sg_no)])
assert pl.names == names + new_names
assert pl.space_groups == (
[GetSpaceGroup(i) for i in sg_no] + [GetSpaceGroup(i) for i in new_sg_no]
)
assert pl.colors == colors + ["tab:green", "tab:red"]

names = [0, 0] # Use as names
point_groups = [432, "m-3m"]
for n, s in zip(names, point_groups):
pl[n] = str(s)
def test_add_phaselist_to_phaselist(self):
"""Add a PhaseList to a PhaseList, also ensuring that new IDs are
given.
"""
names = ["a", "b"]
sg_no = [10, 20]
pl1 = PhaseList(names=names, space_groups=sg_no)
assert pl1.ids == [0, 1]

names2 = ["c", "d"]
sg_no2 = [30, 40]
ids = [4, 5]
pl2 = PhaseList(names=names2, space_groups=sg_no2, ids=ids)
pl1.add(pl2)
assert pl1.names == names + names2
assert pl1.space_groups == (
[GetSpaceGroup(i) for i in sg_no] + [GetSpaceGroup(i) for i in sg_no2]
)
assert pl1.ids == [0, 1, 2, 3]

assert pl.ids == [0, 1]
assert pl.names == [str(n) for n in names]
assert [s.name for s in pl.point_groups] == [str(s) for s in point_groups]
assert pl.structures == [Structure()] * 2
def test_add_to_phaselist_raises(self):
"""Trying to add a Phase with a name already in the PhaseList
raises a ValueError.
"""
pl = PhaseList(names=["a"])
with pytest.raises(ValueError, match="'a' is already in the phase list"):
pl.add(Phase("a"))

@pytest.mark.parametrize(
"key_del, invalid_phase, error_type, error_msg",
Expand Down Expand Up @@ -586,7 +611,7 @@ def test_deepcopy_phaselist(self, phase_list):
pl2 = phase_list.deepcopy()
assert pl2.names == names

phase_list["d"] = "m-3m"
phase_list.add(Phase("d", point_group="m-3m"))
phase_list["d"].color = "g"

assert phase_list.names == names + ["d"]
Expand All @@ -600,7 +625,7 @@ def test_deepcopy_phaselist(self, phase_list):
def test_shallowcopy_phaselist(self, phase_list):
pl2 = phase_list

phase_list["d"] = "m-3m"
phase_list.add(Phase("d", point_group="m-3m"))

assert pl2.names == phase_list.names
assert [s2.name for s2 in pl2.point_groups] == [
Expand Down

0 comments on commit d64800e

Please sign in to comment.