-
Notifications
You must be signed in to change notification settings - Fork 75
/
filter_test.py
96 lines (70 loc) · 2.65 KB
/
filter_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import pynbody
import numpy as np
import numpy.testing as npt
def setup():
    """Build the shared 1000-particle snapshot used by the tests below.

    Positions and velocities are drawn from a normal distribution with
    scale 1.0 and masses from a uniform distribution between 1.0 and 10.0.
    The three arrays are then tagged with kpc, km s^-1 and Msol units and
    the snapshot is stored in the module-level ``f``.
    """
    global f

    # A locally seeded generator makes the fixture identical on every run,
    # so a failing assertion can be reproduced, and leaves NumPy's global
    # random state untouched.
    rng = np.random.RandomState(0)

    f = pynbody.new(1000)
    f['pos'] = rng.normal(scale=1.0, size=f['pos'].shape)
    f['vel'] = rng.normal(scale=1.0, size=f['vel'].shape)
    f['mass'] = rng.uniform(1.0, 10.0, size=f['mass'].shape)

    f['pos'].units = 'kpc'
    f['vel'].units = 'km s^-1'
    f['mass'].units = 'Msol'
def teardown():
    """Drop the module-level snapshot ``f`` once the tests are done."""
    # ``global`` is required so that ``del`` removes the module attribute
    # rather than looking for a local name.
    global f
    del f
def test_sphere():
    """Check that ``Sphere(0.5)`` selects the particles with ``r < 0.5``.

    The selected indices must match a direct ``r < 0.5`` comparison, the
    selected ``x`` and ``r`` values must lie within the radius, and the
    intersection with a sphere specified as ``'500 pc'`` must contain as
    many particles as the unit-less selection.
    """
    global f
    sp = f[pynbody.filt.Sphere(0.5)]

    # assert_array_equal fails on a length mismatch and reports the
    # differing entries; a bare ``(a == b).all()`` can broadcast a
    # length-1 index list and pass by accident.
    npt.assert_array_equal(sp.get_index_list(f), np.where(f['r'] < 0.5)[0])

    assert sp['x'].max() < 0.5
    assert sp['x'].min() > -0.5
    assert sp['r'].max() < 0.5

    # Same radius, given as a string with units.
    sp_units = f[pynbody.filt.Sphere('500 pc')]
    assert len(sp_units.intersect(sp)) == len(sp)
def test_passfilters():
    """Check the HighPass, LowPass and BandPass filters on ``'mass'``.

    Each filter must select a non-empty set, the high- and low-pass
    selections at the same threshold must not overlap, and every selection
    must match the indices obtained from the equivalent direct comparison.
    """
    global f
    hp = f[pynbody.filt.HighPass('mass', 5)]
    lp = f[pynbody.filt.LowPass('mass', 5)]
    bp = f[pynbody.filt.BandPass('mass', 2, 7)]

    assert len(hp) > 0
    assert len(lp) > 0
    assert len(bp) > 0

    assert len(hp.intersect(lp)) == 0

    # assert_array_equal fails on a length mismatch and reports the
    # differing entries; a bare ``(a == b).all()`` can broadcast a
    # length-1 index list and pass by accident.
    npt.assert_array_equal(hp.get_index_list(f), np.where(f['mass'] > 5)[0])
    npt.assert_array_equal(lp.get_index_list(f), np.where(f['mass'] < 5)[0])

    # The product of the two comparison results is used as the combined
    # "both conditions hold" mask.
    in_band = (f['mass'] > 2) * (f['mass'] < 7)
    npt.assert_array_equal(bp.get_index_list(f), np.where(in_band)[0])
def test_logic():
    """Check the ``&``, ``|`` and ``~`` operators on filter objects.

    Combined filters are compared against the single filter, or the whole
    snapshot, that the test expects them to be equivalent to.
    """
    filt = pynbody.filt

    # AND: high-pass at 2 combined with low-pass at 7 versus band-pass 2..7.
    band = f[filt.BandPass('mass', 2, 7)]
    both = f[filt.HighPass('mass', 2) & filt.LowPass('mass', 7)]
    assert both == band

    # OR: low-pass at 1 combined with band-pass 1..2 versus low-pass at 2.
    below_two = f[filt.LowPass('mass', 2)]
    either = f[filt.LowPass('mass', 1) | filt.BandPass('mass', 1, 2)]
    assert either == below_two

    # NOT: a band-pass and its negation must split the snapshot into two
    # non-empty, non-overlapping parts whose union is the full snapshot.
    inside = f[filt.BandPass('mass', 2, 7)]
    outside = f[~filt.BandPass('mass', 2, 7)]
    assert inside.union(outside) == f
    assert len(inside.intersect(outside)) == 0
    assert len(inside) + len(outside) == len(f)
    assert len(inside) != 0
    assert len(outside) != 0
def test_family_filter():
    """Check that ``FamilyFilter`` selects the same particles as ``f.dm`` / ``f.gas``.

    Uses its own local snapshot with 100 dm and 100 gas particles rather
    than the module-level fixture.
    """
    f = pynbody.new(dm=100, gas=100)

    f_dm = f.dm
    f_dm_filter = f[pynbody.filt.FamilyFilter(pynbody.family.dm)]
    f_gas = f.gas
    f_gas_filter = f[pynbody.filt.FamilyFilter(pynbody.family.gas)]

    # assert_array_equal fails on a length mismatch and reports the
    # differing entries; a bare ``(a == b).all()`` can broadcast a
    # length-1 index list and pass by accident.
    npt.assert_array_equal(f_dm.get_index_list(f),
                           f_dm_filter.get_index_list(f))
    npt.assert_array_equal(f_gas.get_index_list(f),
                           f_gas_filter.get_index_list(f))
def test_hashing():
    """Check that filters can be used as dictionary keys.

    Freshly constructed filters with the same arguments must find the
    values stored under the original keys, and a filter that was never
    stored must raise ``KeyError`` on lookup.
    """
    filt = pynbody.filt
    table = {
        filt.Sphere('100 kpc'): 5,
        filt.FamilyFilter(pynbody.family.gas): 10,
    }

    # Look up with new instances, not the objects used as keys above.
    sphere_value = table.get(filt.Sphere('100 kpc'))
    assert sphere_value == 5
    gas_value = table.get(filt.FamilyFilter(pynbody.family.gas))
    assert gas_value == 10

    with npt.assert_raises(KeyError):
        table[filt.FamilyFilter(pynbody.family.dm)]