-
-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
dwt.py
153 lines (125 loc) · 4.85 KB
/
dwt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""Discrete wavelet transform."""
import math
import numpy as np
import pandas as pd
from sktime.datatypes import convert
from sktime.transformations.base import BaseTransformer
__author__ = ["vnicholson1"]
class DWTTransformer(BaseTransformer):
    """Discrete Wavelet Transform Transformer.

    Performs the Haar wavelet transformation on a time series.

    At every level the current signal is split into pairwise sums
    (approximation coefficients) and pairwise differences (wavelet
    coefficients), both divided by ``sqrt(2)``. The next level is computed on
    the approximation coefficients of the previous one. The transformed series
    consists of the approximation coefficients of the last level, followed by
    the wavelet coefficients of each level, ordered from the last level down
    to the first.

    Parameters
    ----------
    num_levels : int, default=3
        Number of levels to perform the Haar wavelet transformation.
        Must be at least 0. With ``num_levels=0`` the values of each series
        are passed through without being transformed.

    Notes
    -----
    Values are paired from the start of the signal, so the trailing value of
    an odd-length signal is ignored at that level. A signal of length 1 is not
    split any further: its single value is used both as the approximation and
    as the wavelet coefficient of that level.
    """

    _tags = {
        "authors": "vnicholson1",
        "scitype:transform-input": "Series",
        # what is the scitype of X: Series, or Panel
        "scitype:transform-output": "Series",
        # what scitype is returned: Primitives, Series, Panel
        "scitype:instancewise": False,  # is this an instance-wise transform?
        "X_inner_mtype": "nested_univ",  # which mtypes do _fit/_predict support for X?
        "y_inner_mtype": "None",  # which mtypes do _fit/_predict support for y?
        "fit_is_empty": True,
    }

    def __init__(self, num_levels=3):
        self.num_levels = num_levels
        super().__init__()

    def _transform(self, X, y=None):
        """Transform X and return a transformed version.

        private _transform containing core logic, called from transform

        Parameters
        ----------
        X : nested pandas DataFrame of shape [n_instances, n_features]
            each cell of X must contain pandas.Series
            Data to fit transform to
        y : ignored argument for interface compatibility
            Additional data, e.g., labels for transformation

        Returns
        -------
        Xt : nested pandas DataFrame of shape [n_instances, n_features]
            each cell of Xt contains pandas.Series
            transformed version of X

        Raises
        ------
        TypeError
            If ``num_levels`` is not an int.
        ValueError
            If ``num_levels`` is negative.
        """
        self._check_parameters()

        Xt = pd.DataFrame()
        # Each column is transformed on its own: it is flattened to a 2D
        # array (one row per instance), decomposed row by row, and the rows
        # are wrapped back into pandas.Series cells.
        for col_name in X.columns:
            flat = convert(
                pd.DataFrame(X[col_name]),
                from_type="nested_univ",
                to_type="numpyflat",
                as_scitype="Panel",
            )
            coeffs = np.asarray(self._extract_wavelet_coefficients(flat))
            Xt[col_name] = [pd.Series(row) for row in coeffs]

        return Xt

    def _extract_wavelet_coefficients(self, data):
        """Extract wavelet coefficients of a 2d array of time series.

        Parameters
        ----------
        data : iterable of 1D sequences
            One time series per entry.

        Returns
        -------
        res : list
            One entry per input series. If ``num_levels`` is 0 the entry is
            the input series itself. Otherwise it is a list holding the
            approximation coefficients of level ``num_levels``, followed by
            the wavelet coefficients of levels ``num_levels``, ..., 2, 1.
        """
        num_levels = self.num_levels
        res = []
        for series in data:
            if num_levels == 0:
                res.append(series)
                continue

            # Wavelet coefficients per level, in the order they are computed
            # (level 1 first).
            details = []
            current = series
            for _ in range(num_levels):
                details.append(self._get_wavelet_coefficients(current))
                # The next level operates on this level's approximation.
                current = self._get_approx_coefficients(current)

            # Final approximation first, then details from coarsest to finest.
            coeffs = list(current)
            for level_details in reversed(details):
                coeffs.extend(level_details)
            res.append(coeffs)
        return res

    def _check_parameters(self):
        """Check the values of parameters passed to DWT.

        Raises
        ------
        TypeError
            If ``num_levels`` is not an int.
        ValueError
            If ``num_levels`` is less than 0.
        """
        if not isinstance(self.num_levels, int):
            raise TypeError(
                "num_levels must be an 'int'. Found "
                f"'{type(self.num_levels).__name__}' instead."
            )
        if self.num_levels < 0:
            raise ValueError("num_levels must have the value of at least 0")

    def _get_approx_coefficients(self, arr):
        """Get the approximate coefficients at a given level.

        Returns a list with ``(arr[2k] + arr[2k + 1]) / sqrt(2)`` for every
        complete pair in ``arr``. A trailing unpaired value is ignored, and a
        length-1 input is returned as a one-element list.
        """
        if len(arr) == 1:
            return [arr[0]]
        root_two = math.sqrt(2)
        return [
            (arr[2 * k] + arr[2 * k + 1]) / root_two for k in range(len(arr) // 2)
        ]

    def _get_wavelet_coefficients(self, arr):
        """Get the wavelet coefficients at a given level.

        Returns a list with ``(arr[2k] - arr[2k + 1]) / sqrt(2)`` for every
        complete pair in ``arr``. A trailing unpaired value is ignored, and a
        length-1 input is returned as a one-element list.
        """
        # if length is 1, just return the value back in a list
        if len(arr) == 1:
            return [arr[0]]
        root_two = math.sqrt(2)
        return [
            (arr[2 * k] - arr[2 * k + 1]) / root_two for k in range(len(arr) // 2)
        ]