Skip to content

Commit

Permalink
Merge pull request #118 from pc494/testing-update
Browse files Browse the repository at this point in the history
Bugfix for get_sample_local()
  • Loading branch information
dnjohnstone committed Sep 6, 2020
2 parents b8e6dfe + b4c3941 commit fdfe9a1
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
8 changes: 6 additions & 2 deletions orix/sampling/sample_generators.py
Expand Up @@ -82,8 +82,12 @@ def get_sample_local(resolution=2, center=None, grid_width=10):
"""

q = uniform_SO3_sample(resolution)
grid_cosine = np.arccos(np.deg2rad(grid_width / 2))
q = q[q.a > grid_cosine]
half_angle = np.deg2rad(grid_width / 2)
half_angles = np.arccos(q.a.data)
mask = np.logical_or(
half_angles < half_angle, half_angles > (2 * np.pi - half_angle)
)
q = q[mask]
if center is not None:
q = center * q
return q
35 changes: 30 additions & 5 deletions orix/tests/test_sampling.py
Expand Up @@ -55,11 +55,36 @@ def test_uniform_SO3_sample_resolution(sample):
assert np.isclose(x, y, rtol=0.025)


def test_get_sample_local_width(fr):
""" Checks that doubling the width 8 folds the number of points """
x = get_sample_local(np.pi, fr, 15).size * 8
y = get_sample_local(np.pi, fr, 30).size
assert np.isclose(x, y, rtol=0.025)
@pytest.mark.parametrize("big,small", [(77, 52), (48, 37)])
def test_get_sample_local_width(big, small):
""" Checks that width follows the expected trend (X - Sin(X)) """
resolution = np.pi

z = get_sample_local(resolution=resolution, grid_width=small)

assert np.all(z.angle_with(Rotation([1,0,0,0])) < np.deg2rad(small))
assert np.any(z.angle_with(Rotation([1,0,0,0])) > np.deg2rad(small - 1.5*resolution))

x_size = z.size
assert x_size > 0
y_size = get_sample_local(resolution=np.pi, grid_width=big).size
x_v = np.deg2rad(small) - np.sin(np.deg2rad(small))
y_v = np.deg2rad(big) - np.sin(np.deg2rad(big))
exp = y_size / x_size
theory = y_v / x_v

# resolution/width is high, so we must be generous on tolerance
assert np.isclose(exp, theory, rtol=0.2)


@pytest.mark.parametrize("width", [60, 33])
def test_get_sample_local_center(fr, width):
""" Checks that the center argument works as expected """
resolution=8
x = get_sample_local(resolution=resolution, center=fr, grid_width=width)
assert np.all((x.angle_with(fr) < np.deg2rad(width)))
# makes sure some of our rotations are inner the outer region
assert np.any(x.angle_with(fr) > np.deg2rad(width - resolution*1.5))


@pytest.fixture(scope="session")
Expand Down

0 comments on commit fdfe9a1

Please sign in to comment.