forked from PyWavelets/pywt
-
Notifications
You must be signed in to change notification settings - Fork 1
/
wp_2d.py
49 lines (38 loc) · 1.28 KB
/
wp_2d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
from pywt import WaveletPacket2D
import pywt.data
# Load the bundled "aero" sample image and decompose it into a two-level 2D
# wavelet packet tree ('db2' wavelet, 'symmetric' signal extension).
arr = pywt.data.aero()
wp2 = WaveletPacket2D(arr, 'db2', 'symmetric', maxlevel=2)

# Show original figure: plt.imshow draws on the current figure, which pyplot
# creates implicitly because no figure exists yet.
plt.imshow(arr, interpolation="nearest", cmap=plt.cm.gray)

# Single-letter child names used to index packet nodes, in display order.
# (PyWavelets names the 2D subbands 'a', 'h', 'v', 'd' -- see the
# WaveletPacket2D docs for their meaning.)
path = list('dvha')
# Show level 1 nodes: one figure with a 2x2 grid holding one panel per
# first-level node, each titled with its path. The fixed 2x2 layout assumes
# `path` holds exactly four node names.
# Coefficients are drawn as sqrt(|c|) so that both large and small magnitudes
# remain visible in the same grayscale image.
fig = plt.figure()
for i, p2 in enumerate(path, start=1):
    ax = fig.add_subplot(2, 2, i)
    ax.imshow(np.sqrt(np.abs(wp2[p2].data)), origin='upper',
              interpolation="nearest", cmap=plt.cm.gray)
    ax.set_title(p2)
# Show level 2 nodes: for each first-level parent p1, open a new figure with
# a 2x2 grid of its four children, addressed by the two-letter path p1 + p2
# (e.g. 'ad', 'av', 'ah', 'aa' for parent 'a').
for p1 in path:
    fig = plt.figure()
    for i, p2 in enumerate(path, start=1):
        p1p2 = p1 + p2
        ax = fig.add_subplot(2, 2, i)
        # sqrt(|c|) compresses the coefficient range for display.
        ax.imshow(np.sqrt(np.abs(wp2[p1p2].data)), origin='upper',
                  interpolation="nearest", cmap=plt.cm.gray)
        ax.set_title(p1p2)
# Show every level 2 node in a single figure, laid out in frequency order:
# the nested rows returned by get_level(2, 'freq') map directly onto the
# subplot grid (row r, column c -> row-major position r * n_cols + c + 1).
# Each title is the node path followed by its expanded row/column paths.
fig = plt.figure()
level2_rows = wp2.get_level(2, 'freq')
# The grid height must come from the number of rows, not the row length;
# the two only coincide when the grid is square.
n_rows = len(level2_rows)
for r, row in enumerate(level2_rows):
    n_cols = len(row)
    for c, node in enumerate(row):
        ax = fig.add_subplot(n_rows, n_cols, r * n_cols + c + 1)
        row_path, col_path = wp2.expand_2d_path(node.path)
        ax.set_title("%s=(%s row, %s col)" % (node.path, row_path, col_path))
        ax.imshow(np.sqrt(np.abs(node.data)), origin='upper',
                  interpolation="nearest", cmap=plt.cm.gray)

# Display all figures created above.
plt.show()