/
arima.pyx
1205 lines (1010 loc) · 47.7 KB
/
arima.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# distutils: language = c++
from cuml.internals.safe_imports import (
cpu_only_import,
gpu_only_import_from,
null_decorator
)
np = cpu_only_import('numpy')
nvtx_annotate = gpu_only_import_from("nvtx", "annotate", alt=null_decorator)
from libc.stdint cimport uintptr_t
from libcpp cimport bool
from libcpp.vector cimport vector
from typing import Tuple, Dict, Mapping, Optional, Union
import cuml.internals
from cuml.internals.array import CumlArray
from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.internals.base import Base
from pylibraft.common.handle cimport handle_t
from cuml.tsa.batched_lbfgs import batched_fmin_lbfgs_b
import cuml.internals.logger as logger
from cuml.common import has_scipy
from cuml.internals.input_utils import input_to_cuml_array
from cuml.internals import _deprecate_pos_args
cdef extern from "cuml/tsa/arima_common.h" namespace "ML":
cdef cppclass ARIMAParams[DataT]:
DataT* mu
DataT* beta
DataT* ar
DataT* ma
DataT* sar
DataT* sma
DataT* sigma2
cdef cppclass ARIMAMemory[DataT]:
ARIMAMemory(const ARIMAOrder& order, int batch_size, int n_obs,
char* in_buf)
@staticmethod
size_t compute_size(const ARIMAOrder& order, int batch_size, int n_obs)
cdef extern from "cuml/tsa/batched_arima.hpp" namespace "ML":
ctypedef enum LoglikeMethod: CSS, MLE
void cpp_pack "pack" (
handle_t& handle, const ARIMAParams[double]& params,
const ARIMAOrder& order, int batch_size, double* param_vec)
void cpp_unpack "unpack" (
handle_t& handle, ARIMAParams[double]& params,
const ARIMAOrder& order, int batch_size, const double* param_vec)
bool detect_missing(
handle_t& handle, const double* d_y, int n_elem)
void batched_diff(
handle_t& handle, double* d_y_diff, const double* d_y, int batch_size,
int n_obs, const ARIMAOrder& order)
void batched_loglike(
handle_t& handle, const ARIMAMemory[double]& arima_mem,
const double* y, const double* d_exog, int batch_size, int nobs,
const ARIMAOrder& order, const double* params, double* loglike,
bool trans, bool host_loglike, LoglikeMethod method, int truncate)
void batched_loglike(
handle_t& handle, const ARIMAMemory[double]& arima_mem,
const double* y, const double* d_exog, int batch_size, int n_obs,
const ARIMAOrder& order, const ARIMAParams[double]& params,
double* loglike, bool trans, bool host_loglike, LoglikeMethod method,
int truncate)
void batched_loglike_grad(
handle_t& handle, const ARIMAMemory[double]& arima_mem,
const double* d_y, const double* d_exog, int batch_size, int nobs,
const ARIMAOrder& order, const double* d_x, double* d_grad, double h,
bool trans, LoglikeMethod method, int truncate)
void cpp_predict "predict" (
handle_t& handle, const ARIMAMemory[double]& arima_mem,
const double* d_y, const double* d_exog, const double* d_exog_fut,
int batch_size, int nobs, int start, int end, const ARIMAOrder& order,
const ARIMAParams[double]& params, double* d_y_p, bool pre_diff,
double level, double* d_lower, double* d_upper)
void information_criterion(
handle_t& handle, const ARIMAMemory[double]& arima_mem,
const double* d_y, const double* d_exog, int batch_size, int nobs,
const ARIMAOrder& order, const ARIMAParams[double]& params,
double* ic, int ic_type)
void estimate_x0(
handle_t& handle, ARIMAParams[double]& params, const double* d_y,
const double* d_exog, int batch_size, int nobs,
const ARIMAOrder& order, bool missing)
cdef extern from "cuml/tsa/batched_kalman.hpp" namespace "ML":
void batched_jones_transform(
handle_t& handle, ARIMAMemory[double]& arima_mem,
const ARIMAOrder& order, int batchSize, bool isInv,
const double* h_params, double* h_Tparams)
cdef class ARIMAParamsWrapper:
"""A wrapper class for ARIMAParams"""
cdef ARIMAParams[double] params
def __cinit__(self, model):
cdef ARIMAOrder order = model.order
cdef uintptr_t d_mu_ptr = \
model.mu_.ptr if order.k else <uintptr_t> NULL
cdef uintptr_t d_beta_ptr = \
model.beta_.ptr if order.n_exog else <uintptr_t> NULL
cdef uintptr_t d_ar_ptr = \
model.ar_.ptr if order.p else <uintptr_t> NULL
cdef uintptr_t d_ma_ptr = \
model.ma_.ptr if order.q else <uintptr_t> NULL
cdef uintptr_t d_sar_ptr = \
model.sar_.ptr if order.P else <uintptr_t> NULL
cdef uintptr_t d_sma_ptr = \
model.sma_.ptr if order.Q else <uintptr_t> NULL
cdef uintptr_t d_sigma2_ptr = <uintptr_t> model.sigma2_.ptr
self.params.mu = <double*> d_mu_ptr
self.params.beta = <double*> d_beta_ptr
self.params.ar = <double*> d_ar_ptr
self.params.ma = <double*> d_ma_ptr
self.params.sar = <double*> d_sar_ptr
self.params.sma = <double*> d_sma_ptr
self.params.sigma2 = <double*> d_sigma2_ptr
class ARIMA(Base):
"""
Implements a batched ARIMA model for in- and out-of-sample
time-series prediction, with support for seasonality (SARIMA)
ARIMA stands for Auto-Regressive Integrated Moving Average.
See https://en.wikipedia.org/wiki/Autoregressive_integrated_moving_average
This class can fit an ARIMA(p,d,q) or ARIMA(p,d,q)(P,D,Q)_s model to a
batch of time series of the same length (or various lengths, using missing
values at the start for padding).
The implementation is designed to give the best performance when using
large batches of time series.
Parameters
----------
endog : dataframe or array-like (device or host)
Endogenous variable, assumed to have each time series in columns.
Acceptable formats: cuDF DataFrame, cuDF Series, NumPy ndarray,
Numba device ndarray, cuda array interface compliant array like CuPy.
Missing values are accepted, represented by NaN.
order : Tuple[int, int, int] (default=(1,1,1))
The ARIMA order (p, d, q) of the model
seasonal_order : Tuple[int, int, int, int] (default=(0,0,0,0))
The seasonal ARIMA order (P, D, Q, s) of the model
exog : dataframe or array-like (device or host) (default=None)
Exogenous variables, assumed to have each time series in columns,
such that variables associated with a same batch member are adjacent
(number of columns: n_exog * batch_size)
Acceptable formats: cuDF DataFrame, cuDF Series, NumPy ndarray,
Numba device ndarray, cuda array interface compliant array like CuPy.
Missing values are not supported.
fit_intercept : bool or int (default = True)
Whether to include a constant trend mu in the model
simple_differencing : bool or int (default = True)
If True, the data is differenced before being passed to the Kalman
filter. If False, differencing is part of the state-space model.
In some cases this setting can be ignored: computing forecasts with
confidence intervals will force it to False ; fitting with the CSS
method will force it to True.
Note: that forecasts are always for the original series, whereas
statsmodels computes forecasts for the differenced series when
simple_differencing is True.
handle : cuml.Handle
Specifies the cuml.handle that holds internal CUDA state for
computations in this model. Most importantly, this specifies the CUDA
stream that will be used for the model's computations, so users can
run different models concurrently in different streams by creating
handles in several streams.
If it is None, a new one is created.
verbose : int or boolean, default=False
Sets logging level. It must be one of `cuml.common.logger.level_*`.
See :ref:`verbosity-levels` for more info.
output_type : {'input', 'array', 'dataframe', 'series', 'df_obj', \
'numba', 'cupy', 'numpy', 'cudf', 'pandas'}, default=None
Return results and set estimator attributes to the indicated output
type. If None, the output type set at the module level
(`cuml.global_settings.output_type`) will be used. See
:ref:`output-data-type-configuration` for more info.
convert_dtype : boolean
When set to True, the model will automatically convert the inputs to
np.float64.
Attributes
----------
order : ARIMAOrder
The ARIMA order of the model (p, d, q, P, D, Q, s, k, n_exog)
d_y : device array
Time series data on device
n_obs : int
Number of observations
batch_size : int
Number of time series in the batch
dtype : numpy.dtype
Floating-point type of the data and parameters
niter : numpy.ndarray
After fitting, contains the number of iterations before convergence
for each time series.
Notes
-----
*Performance:* Let :math:`r=max(p+s*P, q+s*Q+1)`. The device memory used
for most operations is
:math:\
`O(\\mathtt{batch\\_size}*\\mathtt{n\\_obs} + \\mathtt{batch\\_size}*r^2)`.
The execution time is a linear function of `n_obs` and `batch_size`
(if `batch_size` is large), but grows very fast with `r`.
The performance is optimized for very large batch sizes (e.g thousands of
series).
References
----------
This class is heavily influenced by the Python library `statsmodels`,
particularly `statsmodels.tsa.statespace.sarimax.SARIMAX`.
See https://www.statsmodels.org/stable/statespace.html.
Additionally the following book is a useful reference:
"Time Series Analysis by State Space Methods",
J. Durbin, S.J. Koopman, 2nd Edition (2012).
Examples
--------
.. code-block:: python
>>> import cupy as cp
>>> from cuml.tsa.arima import ARIMA
>>> # Create seasonal data with a trend, a seasonal pattern and noise
>>> n_obs = 100
>>> cp.random.seed(12)
>>> x = cp.linspace(0, 1, n_obs)
>>> pattern = cp.array([[0.05, 0.0], [0.07, 0.03],
... [-0.03, 0.05], [0.02, 0.025]])
>>> noise = cp.random.normal(scale=0.01, size=(n_obs, 2))
>>> y = (cp.column_stack((0.5*x, -0.25*x)) + noise
... + cp.tile(pattern, (25, 1)))
>>> # Fit a seasonal ARIMA model
>>> model = ARIMA(y,
... order=(0,1,1),
... seasonal_order=(0,1,1,4),
... fit_intercept=False)
>>> model.fit()
ARIMA(...)
>>> # Forecast
>>> fc = model.forecast(10)
>>> print(fc) # doctest: +SKIP
[[ 0.55204599 -0.25681163]
[ 0.57430705 -0.2262438 ]
[ 0.48120315 -0.20583011]
[ 0.535594 -0.24060046]
[ 0.57207541 -0.26695497]
[ 0.59433647 -0.23638713]
[ 0.50123257 -0.21597344]
[ 0.55562342 -0.25074379]
[ 0.59210483 -0.27709831]
[ 0.61436589 -0.24653047]]
"""
d_y = CumlArrayDescriptor()
# TODO: (MDD) Should this be public? Its not listed in the attributes doc
_d_y_diff = CumlArrayDescriptor()
_temp_mem = CumlArrayDescriptor()
mu_ = CumlArrayDescriptor()
beta_ = CumlArrayDescriptor()
ar_ = CumlArrayDescriptor()
ma_ = CumlArrayDescriptor()
sar_ = CumlArrayDescriptor()
sma_ = CumlArrayDescriptor()
sigma2_ = CumlArrayDescriptor()
@_deprecate_pos_args(version="21.06")
def __init__(self,
endog,
*,
order: Tuple[int, int, int] = (1, 1, 1),
seasonal_order: Tuple[int, int, int, int] = (0, 0, 0, 0),
exog=None,
fit_intercept=True,
simple_differencing=True,
handle=None,
verbose=False,
output_type=None,
convert_dtype=True):
if not has_scipy():
raise RuntimeError("Scipy is needed to run cuML's ARIMA estimator."
" Please install it to enable ARIMA "
"estimation.")
# Initialize base class
super().__init__(handle=handle,
verbose=verbose,
output_type=output_type)
self._set_base_attributes(output_type=endog)
# Check validity of the ARIMA order and seasonal order
p, d, q = order
P, D, Q, s = seasonal_order
if P + D + Q > 0 and s < 2:
raise ValueError("ERROR: Invalid period for seasonal ARIMA: {}"
.format(s))
if d + D > 2:
raise ValueError("ERROR: Invalid order. Required: d+D <= 2")
if s != 0 and (p >= s or q >= s):
raise ValueError("ERROR: Invalid order. Required: s > p, s > q")
if p + q + P + Q + int(fit_intercept) == 0:
raise ValueError("ERROR: Invalid order. At least one parameter"
" among p, q, P, Q and fit_intercept must be"
" non-zero")
if p > 8 or P > 8 or q > 8 or Q > 8:
raise ValueError("ERROR: Invalid order. Required: p,q,P,Q <= 8")
if max(p + s * P, q + s * Q) > 1024:
raise ValueError("ERROR: Invalid order. "
"Required: max(p+s*P, q+s*Q) <= 1024")
# Endogenous variable. Float64 only for now.
self.d_y, self.n_obs, self.batch_size, self.dtype \
= input_to_cuml_array(
endog, check_dtype=np.float64,
convert_to_dtype=(np.float64 if convert_dtype else None))
if self.n_obs < d + s * D + 1:
raise ValueError("ERROR: Number of observations too small for the"
" given order")
# Exogenous variables
if exog is not None:
self.d_exog, n_obs_exog, n_cols_exog, _ \
= input_to_cuml_array(exog, check_dtype=np.float64)
if n_cols_exog % self.batch_size != 0:
raise ValueError("Number of columns in exog is not a multiple"
" of batch_size")
if n_obs_exog != self.n_obs:
raise ValueError("Number of observations mismatch between"
" endog and exog")
n_exog = n_cols_exog // self.batch_size
else:
n_exog = 0
# Set the ARIMA order
cdef ARIMAOrder cpp_order
cpp_order.p, cpp_order.d, cpp_order.q = order
cpp_order.P, cpp_order.D, cpp_order.Q, cpp_order.s = seasonal_order
cpp_order.k = int(fit_intercept)
cpp_order.n_exog = n_exog
self.order = cpp_order
self.simple_differencing = simple_differencing
self._d_y_diff = CumlArray.empty(
(self.n_obs - d - s * D, self.batch_size), self.dtype)
if n_exog > 0:
self._d_exog_diff = CumlArray.empty(
(self.n_obs - d - s * D, self.batch_size * n_exog),
self.dtype)
self.n_obs_diff = self.n_obs - d - D * s
# Allocate temporary storage
temp_mem_size = ARIMAMemory[double].compute_size(
cpp_order, <int> self.batch_size, <int> self.n_obs)
self._temp_mem = CumlArray.empty(temp_mem_size, np.byte)
self._initial_calc()
@cuml.internals.api_base_return_any_skipall
def _initial_calc(self):
"""
This separates the initial calculation from the initialization to make
the CumlArrayDescriptors work
"""
cdef uintptr_t d_y_ptr = self.d_y.ptr
cdef uintptr_t d_y_diff_ptr = self._d_y_diff.ptr
cdef uintptr_t d_exog_ptr
cdef uintptr_t d_exog_diff_ptr
cdef handle_t* handle_ = <handle_t*><size_t>self.handle.getHandle()
cdef ARIMAOrder cpp_order_diff = self.order
# Detect missing observations
self.missing = detect_missing(handle_[0], <double*> d_y_ptr,
<int> self.batch_size * self.n_obs)
if self.missing and self.simple_differencing:
logger.warn("Missing observations detected."
" Forcing simple_differencing=False")
self.simple_differencing = False
if self.simple_differencing:
# Compute the differenced series
batched_diff(handle_[0], <double*> d_y_diff_ptr, <double*> d_y_ptr,
<int> self.batch_size, <int> self.n_obs, self.order)
# Create a version of the order for the differenced series
cpp_order_diff.d = 0
cpp_order_diff.D = 0
self.order_diff = cpp_order_diff
if cpp_order_diff.n_exog > 0:
d_exog_ptr = self.d_exog.ptr
d_exog_diff_ptr = self._d_exog_diff.ptr
batched_diff(handle_[0], <double*> d_exog_diff_ptr,
<double*> d_exog_ptr,
<int> self.batch_size * cpp_order_diff.n_exog,
<int> self.n_obs, self.order)
else:
self.order_diff = None
def __str__(self):
cdef ARIMAOrder order = self.order
intercept_str = 'c' if order.k else 'n'
if order.s:
return "ARIMA({},{},{})({},{},{})_{} ({}) - {} series".format(
order.p, order.d, order.q, order.P, order.D, order.Q, order.s,
intercept_str, self.batch_size)
else:
return "ARIMA({},{},{}) ({}) - {} series".format(
order.p, order.d, order.q, intercept_str, self.batch_size)
@nvtx_annotate(message="tsa.arima.ARIMA._ic", domain="cuml_python")
@cuml.internals.api_base_return_any_skipall
def _ic(self, ic_type: str):
"""Wrapper around C++ information_criterion
"""
cdef handle_t* handle_ = <handle_t*><size_t>self.handle.getHandle()
cdef ARIMAOrder order = self.order
cdef ARIMAOrder order_kf = \
self.order_diff if self.simple_differencing else self.order
cdef ARIMAParams[double] cpp_params = ARIMAParamsWrapper(self).params
ic = CumlArray.empty(self.batch_size, self.dtype)
cdef uintptr_t d_ic_ptr = ic.ptr
cdef uintptr_t d_y_kf_ptr = \
self._d_y_diff.ptr if self.simple_differencing else self.d_y.ptr
cdef uintptr_t d_exog_kf_ptr = <uintptr_t> NULL
if order.n_exog:
d_exog_kf_ptr = (self._d_exog_diff.ptr if self.simple_differencing
else self.d_exog.ptr)
n_obs_kf = (self.n_obs_diff if self.simple_differencing
else self.n_obs)
ic_name_to_number = {"aic": 0, "aicc": 1, "bic": 2}
cdef int ic_type_id
try:
ic_type_id = ic_name_to_number[ic_type.lower()]
except KeyError as e:
raise NotImplementedError("IC type '{}' unknown".format(ic_type))
cdef uintptr_t d_temp_mem = self._temp_mem.ptr
arima_mem_ptr = new ARIMAMemory[double](
order, <int> self.batch_size, <int> self.n_obs,
<char*> d_temp_mem)
information_criterion(handle_[0], arima_mem_ptr[0],
<double*> d_y_kf_ptr, <double*> d_exog_kf_ptr,
<int> self.batch_size, <int> n_obs_kf, order_kf,
cpp_params, <double*> d_ic_ptr,
<int> ic_type_id)
del arima_mem_ptr
return ic
@property
def aic(self) -> CumlArray:
"""Akaike Information Criterion"""
return self._ic("aic")
@property
def aicc(self) -> CumlArray:
"""Corrected Akaike Information Criterion"""
return self._ic("aicc")
@property
def bic(self) -> CumlArray:
"""Bayesian Information Criterion"""
return self._ic("bic")
@property
def complexity(self):
"""Model complexity (number of parameters)"""
cdef ARIMAOrder order = self.order
return (order.p + order.P + order.q + order.Q + order.k + order.n_exog
+ 1)
@cuml.internals.api_base_return_generic(input_arg=None)
def get_fit_params(self) -> Dict[str, CumlArray]:
"""Get all the fit parameters. Not to be confused with get_params
Note: pack() can be used to get a compact vector of the parameters
Returns
-------
params: Dict[str, array-like]
A dictionary of parameter names and associated arrays
The key names are in {"mu", "ar", "ma", "sar", "sma", "sigma2"}
The shape of the arrays are (batch_size,) for mu and sigma2 and
(n, batch_size) for any other type, where n is the corresponding
number of parameters of this type.
"""
cdef ARIMAOrder order = self.order
params = dict()
names = ["mu", "beta", "ar", "ma", "sar", "sma", "sigma2"]
criteria = [order.k, order.n_exog, order.p, order.q, order.P, order.Q,
True]
for i in range(len(names)):
if criteria[i] > 0:
params[names[i]] = getattr(self, "{}_".format(names[i]))
return params
def set_fit_params(self, params: Mapping[str, object]):
"""Set all the fit parameters. Not to be confused with ``set_params``
Note: `unpack()` can be used to load a compact vector of the
parameters
Parameters
----------
params: Mapping[str, array-like]
A dictionary of parameter names and associated arrays
The key names are in {"mu", "ar", "ma", "sar", "sma", "sigma2"}
The shape of the arrays are (batch_size,) for mu and sigma2 and
(n, batch_size) for any other type, where n is the corresponding
number of parameters of this type.
"""
for param_name in ["mu", "beta", "ar", "ma", "sar", "sma", "sigma2"]:
if param_name in params:
array, *_ = input_to_cuml_array(params[param_name],
check_dtype=np.float64)
setattr(self, "{}_".format(param_name), array)
def get_param_names(self):
raise NotImplementedError
def get_param_names(self):
"""
.. warning:: ARIMA is unable to be cloned at this time.
The methods: `get_param_names()`, `get_params` and
`set_params` will raise ``NotImplementedError``
"""
raise NotImplementedError("ARIMA is unable to be cloned via "
"`get_params` and `set_params`.")
def get_params(self, deep=True):
"""
.. warning:: ARIMA is unable to be cloned at this time.
The methods: `get_param_names()`, `get_params` and
`set_params` will raise ``NotImplementedError``
"""
raise NotImplementedError("ARIMA is unable to be cloned via "
"`get_params` and `set_params`.")
def set_params(self, **params):
"""
.. warning:: ARIMA is unable to be cloned at this time.
The methods: `get_param_names()`, `get_params` and
`set_params` will raise ``NotImplementedError``
"""
raise NotImplementedError("ARIMA is unable to be cloned via "
"`get_params` and `set_params`.")
@cuml.internals.api_base_return_generic(input_arg=None)
def predict(
self,
start=0,
end=None,
level=None,
exog=None,
) -> Union[CumlArray, Tuple[CumlArray, CumlArray, CumlArray]]:
"""Compute in-sample and/or out-of-sample prediction for each series
Parameters
----------
start : int (default = 0)
Index where to start the predictions (0 <= start <= num_samples)
end : int (default = None)
Index where to end the predictions, excluded (end > start), or
``None`` to predict until the last observation
level : float or None (default = None)
Confidence level for prediction intervals, or None to return only
the point forecasts. ``0 < level < 1``
exog : dataframe or array-like (device or host)
Future values for exogenous variables. Assumed to have each time
series in columns, such that variables associated with a same
batch member are adjacent.
Shape = (end - n_obs, n_exog * batch_size)
Returns
-------
y_p : array-like (device)
Predictions. Shape = (end - start, batch_size)
lower: array-like (device) (optional)
Lower limit of the prediction interval if ``level != None``
Shape = (end - start, batch_size)
upper: array-like (device) (optional)
Upper limit of the prediction interval if ``level != None``
Shape = (end - start, batch_size)
Examples
--------
.. code-block:: python
from cuml.tsa.arima import ARIMA
model = ARIMA(ys, order=(1,1,1))
model.fit()
y_pred = model.predict()
"""
cdef ARIMAOrder order = self.order
cdef ARIMAParams[double] cpp_params = ARIMAParamsWrapper(self).params
if start < 0:
raise ValueError("ERROR(`predict`): start < 0")
elif start > self.n_obs:
raise ValueError("ERROR(`predict`): There can't be a gap between"
" the data and the prediction")
elif end <= start:
raise ValueError("ERROR(`predict`): end <= start")
elif self.simple_differencing and start < order.d + order.D * order.s:
logger.warn("Predictions before {} are undefined when using"
" simple_differencing=True, will be set to NaN"
.format(order.d + order.D * order.s))
if level is not None:
if level <= 0 or level >= 1:
raise ValueError("ERROR: Invalid confidence level: {}"
.format(level))
elif level > 0 and start < self.n_obs:
raise ValueError("ERROR: Prediction intervals can only be"
" computed for out-of-sample predictions")
if end is None:
end = self.n_obs
if order.n_exog > 0 and end > self.n_obs and exog is None:
raise ValueError("The model was fit with a regression component,"
" so future values must be provided via `exog`")
elif order.n_exog == 0 and exog is not None:
raise ValueError("A value was given for `exog` but the model was"
" fit without any regression component")
elif end <= self.n_obs and exog is not None:
raise ValueError("A value was given for `exog` but only in-sample"
" predictions were requested")
cdef handle_t* handle_ = <handle_t*><size_t>self.handle.getHandle()
predict_size = end - start
# Future values of the exogenous variables
cdef uintptr_t d_exog_fut_ptr = <uintptr_t> NULL
if order.n_exog and end > self.n_obs:
d_exog_fut, n_obs_fut, n_cols_fut, _ \
= input_to_cuml_array(exog, check_dtype=np.float64)
if n_obs_fut != end - self.n_obs:
raise ValueError(
"Dimensions mismatch: `exog` should contain {}"
" observations per column".format(end - self.n_obs))
elif n_cols_fut != self.batch_size * order.n_exog:
raise ValueError(
"Dimensions mismatch: `exog` should have {} columns"
.format(self.batch_size * order.n_exog))
d_exog_fut_ptr = d_exog_fut.ptr
# allocate predictions and intervals device memory
cdef uintptr_t d_y_p_ptr = <uintptr_t> NULL
cdef uintptr_t d_lower_ptr = <uintptr_t> NULL
cdef uintptr_t d_upper_ptr = <uintptr_t> NULL
d_y_p = CumlArray.empty((predict_size, self.batch_size),
dtype=np.float64, order="F")
d_y_p_ptr = d_y_p.ptr
if level is not None:
d_lower = CumlArray.empty((predict_size, self.batch_size),
dtype=np.float64, order="F")
d_upper = CumlArray.empty((predict_size, self.batch_size),
dtype=np.float64, order="F")
d_lower_ptr = d_lower.ptr
d_upper_ptr = d_upper.ptr
cdef uintptr_t d_y_ptr = self.d_y.ptr
cdef uintptr_t d_exog_ptr = <uintptr_t> NULL
if order.n_exog:
d_exog_ptr = self.d_exog.ptr
cdef uintptr_t d_temp_mem = self._temp_mem.ptr
arima_mem_ptr = new ARIMAMemory[double](
order, <int> self.batch_size, <int> self.n_obs,
<char*> d_temp_mem)
cpp_predict(handle_[0], arima_mem_ptr[0], <double*>d_y_ptr,
<double*>d_exog_ptr, <double*>d_exog_fut_ptr,
<int> self.batch_size, <int> self.n_obs, <int> start,
<int> end, order, cpp_params, <double*>d_y_p_ptr,
<bool> self.simple_differencing,
<double> (0 if level is None else level),
<double*> d_lower_ptr, <double*> d_upper_ptr)
del arima_mem_ptr
if level is None:
return d_y_p
else:
return (d_y_p,
d_lower,
d_upper)
@nvtx_annotate(message="tsa.arima.ARIMA.forecast", domain="cuml_python")
@cuml.internals.api_base_return_generic_skipall
def forecast(
self,
nsteps: int,
level=None,
exog=None
) -> Union[CumlArray, Tuple[CumlArray, CumlArray, CumlArray]]:
"""Forecast the given model `nsteps` into the future.
Parameters
----------
nsteps : int
The number of steps to forecast beyond end of the given series
level : float or None (default = None)
Confidence level for prediction intervals, or None to return only
the point forecasts. 0 < level < 1
exog : dataframe or array-like (device or host) (default=None)
Future values for exogenous variables. Assumed to have each time
series in columns, such that variables associated with a same
batch member are adjacent.
Shape = (nsteps, n_exog * batch_size)
Returns
-------
y_fc : array-like
Forecasts. Shape = (nsteps, batch_size)
lower : array-like (device) (optional)
Lower limit of the prediction interval if level != None
Shape = (end - start, batch_size)
upper : array-like (device) (optional)
Upper limit of the prediction interval if level != None
Shape = (end - start, batch_size)
Examples
--------
.. code-block:: python
from cuml.tsa.arima import ARIMA
...
model = ARIMA(ys, order=(1,1,1))
model.fit()
y_fc = model.forecast(10)
"""
return self.predict(self.n_obs, self.n_obs + nsteps, level, exog)
@cuml.internals.api_base_return_any_skipall
def _create_arrays(self):
"""Create the parameter arrays if non-existing"""
cdef ARIMAOrder order = self.order
if order.k and not hasattr(self, "mu_"):
self.mu_ = CumlArray.empty(self.batch_size, np.float64)
if order.n_exog and not hasattr(self, "beta_"):
self.beta_ = CumlArray.empty((order.n_exog, self.batch_size),
np.float64)
if order.p and not hasattr(self, "ar_"):
self.ar_ = CumlArray.empty((order.p, self.batch_size),
np.float64)
if order.q and not hasattr(self, "ma_"):
self.ma_ = CumlArray.empty((order.q, self.batch_size),
np.float64)
if order.P and not hasattr(self, "sar_"):
self.sar_ = CumlArray.empty((order.P, self.batch_size),
np.float64)
if order.Q and not hasattr(self, "sma_"):
self.sma_ = CumlArray.empty((order.Q, self.batch_size),
np.float64)
if not hasattr(self, "sigma2_"):
self.sigma2_ = CumlArray.empty(self.batch_size, np.float64)
@nvtx_annotate(message="tsa.arima.ARIMA._estimate_x0",
domain="cuml_python")
@cuml.internals.api_base_return_any_skipall
def _estimate_x0(self):
"""Internal method. Estimate initial parameters of the model.
"""
self._create_arrays()
cdef ARIMAOrder order = self.order
cdef ARIMAParams[double] cpp_params = ARIMAParamsWrapper(self).params
cdef uintptr_t d_y_ptr = self.d_y.ptr
cdef uintptr_t d_exog_ptr = <uintptr_t> NULL
if order.n_exog:
d_exog_ptr = self.d_exog.ptr
cdef handle_t* handle_ = <handle_t*><size_t>self.handle.getHandle()
# Call C++ function
estimate_x0(handle_[0], cpp_params, <double*> d_y_ptr,
<double*> d_exog_ptr, <int> self.batch_size,
<int> self.n_obs, order, <bool> self.missing)
@cuml.internals.api_base_return_any_skipall
def fit(self,
start_params: Optional[Mapping[str, object]] = None,
opt_disp: int = -1,
h: float = 1e-8,
maxiter: int = 1000,
method="ml",
truncate: int = 0) -> "ARIMA":
r"""Fit the ARIMA model to each time series.
Parameters
----------
start_params : Mapping[str, array-like] (optional)
A mapping (e.g dictionary) of parameter names and associated arrays
The key names are in {"mu", "ar", "ma", "sar", "sma", "sigma2"}
The shape of the arrays are (batch_size,) for mu and sigma2
parameters and (n, batch_size) for any other type, where n is the
corresponding number of parameters of this type.
Pass None for automatic estimation (recommended)
opt_disp : int
Fit diagnostic level (for L-BFGS solver):
* `-1` for no output (default)
* `0<n<100` for output every `n` steps
* `n>100` for more detailed output
h : float (default=1e-8)
Finite-differencing step size. The gradient is computed using
forward finite differencing:
:math:`g = \frac{f(x + \mathtt{h}) - f(x)}{\mathtt{h}} + O(\mathtt{h})`
maxiter : int (default=1000)
Maximum number of iterations of L-BFGS-B
method : str (default="ml")
Estimation method - "css", "css-ml" or "ml".
CSS uses a sum-of-squares approximation.
ML estimates the log-likelihood with statespace methods.
CSS-ML starts with CSS and refines with ML.
truncate : int (default=0)
When using CSS, start the sum of squares after a given number of
observations
""" # noqa
def fit_helper(x_in, fit_method):
def f(x: np.ndarray) -> np.ndarray:
"""The (batched) energy functional returning the negative
log-likelihood (foreach series)."""
# Recall: We maximize LL by minimizing -LL
n_llf = -self._loglike(x, True, fit_method, truncate)
return n_llf / (self.n_obs - 1)
# Optimized finite differencing gradient for batches
def gf(x) -> np.ndarray:
"""The gradient of the (batched) energy functional."""
# Recall: We maximize LL by minimizing -LL
n_gllf = -self._loglike_grad(x, h, True, fit_method, truncate)
return n_gllf / (self.n_obs - 1)
# Check initial parameter sanity
if ((np.isnan(x_in).any()) or (np.isinf(x_in).any())):
raise FloatingPointError(
"Initial parameter vector x has NaN or Inf.")
# Optimize parameters by minimizing log likelihood.
x_out, niter, flags = batched_fmin_lbfgs_b(
f, x_in, self.batch_size, gf, iprint=opt_disp, factr=1000,
maxiter=maxiter)
# Handle non-zero flags with Warning
if (flags != 0).any():
logger.warn("fit: Some batch members had optimizer problems")
return x_out, niter
if start_params is None:
self._estimate_x0()
else:
self.set_fit_params(start_params)
x0 = self._batched_transform(self.pack(), True)
method = method.lower()
if method not in {"css", "css-ml", "ml"}:
raise ValueError("Unknown method: {}".format(method))
if self.missing and (method == "css" or method == "css-ml"):
logger.warn("Missing observations detected."
" Forcing method=\"ml\"")
method = "ml"
if method == "css" or method == "css-ml":
x, self.niter = fit_helper(x0, "css")
if method == "css-ml" or method == "ml":
x, niter = fit_helper(x if method == "css-ml" else x0, "ml")
self.niter = (self.niter + niter) if method == "css-ml" else niter
self.unpack(self._batched_transform(x))
return self
@nvtx_annotate(message="tsa.arima.ARIMA._loglike", domain="cuml_python")
@cuml.internals.api_base_return_any_skipall
def _loglike(self, x, trans=True, method="ml", truncate=0):
"""Compute the batched log-likelihood for the given parameters.
Parameters
----------
x : array-like
Packed parameter array, grouped by series
trans : bool (default=True)
Should the Jones' transform be applied?
Note: The parameters from a fit model are already transformed.
method : str (default="ml")
Estimation method: "css" for sum-of-squares, "ml" for
an estimation with statespace methods
truncate : int (default=0)
When using CSS, start the sum of squares after a given number of
observations
Returns
-------
loglike : numpy.ndarray
Batched log-likelihood. Shape: (batch_size,)
"""
cdef vector[double] vec_loglike
vec_loglike.resize(self.batch_size)
cdef LoglikeMethod ll_method = CSS if method == "css" else MLE
diff = ll_method != MLE or self.simple_differencing
cdef ARIMAOrder order = self.order
cdef ARIMAOrder order_kf = self.order_diff if diff else self.order
d_x_array, *_ = \
input_to_cuml_array(x, check_dtype=np.float64, order='C')
cdef uintptr_t d_x_ptr = d_x_array.ptr
cdef uintptr_t d_y_kf_ptr = \
self._d_y_diff.ptr if diff else self.d_y.ptr
cdef uintptr_t d_exog_kf_ptr = <uintptr_t> NULL
if order.n_exog:
d_exog_kf_ptr = self._d_exog_diff.ptr if diff else self.d_exog.ptr
cdef handle_t* handle_ = <handle_t*><size_t>self.handle.getHandle()
n_obs_kf = (self.n_obs_diff if diff else self.n_obs)
cdef uintptr_t d_temp_mem = self._temp_mem.ptr
arima_mem_ptr = new ARIMAMemory[double](
order, <int> self.batch_size, <int> self.n_obs,
<char*> d_temp_mem)
batched_loglike(handle_[0], arima_mem_ptr[0], <double*> d_y_kf_ptr,
<double*> d_exog_kf_ptr, <int> self.batch_size,
<int> n_obs_kf, order_kf, <double*> d_x_ptr,
<double*> vec_loglike.data(), <bool> trans,
<bool> True, ll_method, <int> truncate)
del arima_mem_ptr
return np.array(vec_loglike, dtype=np.float64)