/
naive_bayes.py
1959 lines (1676 loc) · 67.2 KB
/
naive_bayes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from cuml.common.kernel_utils import cuda_kernel_factory
from cuml.internals.input_utils import input_to_cuml_array, input_to_cupy_array
from cuml.prims.array import binarize
from cuml.prims.label import invert_labels
from cuml.prims.label import check_labels
from cuml.prims.label import make_monotonic
from cuml.internals.import_utils import has_scipy
from cuml.common.doc_utils import generate_docstring
from cuml.internals.mixins import ClassifierMixin
from cuml.internals.base import Base
from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.common import CumlArray
import math
import warnings
from cuml.internals.safe_imports import (
gpu_only_import,
gpu_only_import_from,
null_decorator,
)
nvtx_annotate = gpu_only_import_from("nvtx", "annotate", alt=null_decorator)
cp = gpu_only_import("cupy")
cupyx = gpu_only_import("cupyx")
def count_features_coo_kernel(float_dtype, int_dtype):
    """
    A simple reduction kernel that takes in a sparse (COO) array
    of features and computes the sum (or sum squared) for each class
    label

    Each thread handles one stored (non-zero) element ``i``: it looks up
    the element's row and column, reads the label of that row, optionally
    scales the value by the weight of that row, optionally squares it, and
    atomically adds it to ``out[col * n_classes + label]``.

    Parameters
    ----------
    float_dtype : dtype
        Substituted for ``{0}`` in the kernel source (type of ``out``,
        ``vals`` and ``weights``).
    int_dtype : dtype
        Substituted for ``{1}`` in the kernel source (type of ``labels``).

    Returns
    -------
    The object produced by ``cuda_kernel_factory``; elsewhere in this file
    it is launched as ``kernel(grid, block, args)``.
    """
    # NOTE: weights are per sample, so they are indexed by the element's
    # row (as the dense kernel does), not by the nnz index ``i``.
    kernel_str = r"""({0} *out,
int *rows, int *cols,
{0} *vals, int nnz,
int n_rows, int n_cols,
{1} *labels,
{0} *weights,
bool has_weights,
int n_classes,
bool square) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if(i >= nnz) return;
int row = rows[i];
int col = cols[i];
{0} val = vals[i];
{1} label = labels[row];
unsigned out_idx = (col * n_classes) + label;
if(has_weights)
val *= weights[row];
if(square) val *= val;
atomicAdd(out + out_idx, val);
}"""
    return cuda_kernel_factory(
        kernel_str, (float_dtype, int_dtype), "count_features_coo"
    )
def count_classes_kernel(float_dtype, int_dtype):
    """
    Build the kernel that counts how many rows carry each class label.

    Each thread handles one row: it reads ``labels[row]`` and atomically
    adds 1 (cast to the float type) to ``out[label]``.

    Parameters
    ----------
    float_dtype : dtype
        Substituted for ``{0}`` in the kernel source (type of ``out``).
    int_dtype : dtype
        Substituted for ``{1}`` in the kernel source (type of ``labels``).

    Returns
    -------
    The object produced by ``cuda_kernel_factory``; elsewhere in this file
    it is launched as ``kernel(grid, block, args)``.
    """
    kernel_str = r"""
({0} *out, int n_rows, {1} *labels) {
int row = blockIdx.x * blockDim.x + threadIdx.x;
if(row >= n_rows) return;
{1} label = labels[row];
atomicAdd(out + label, ({0})1);
}"""
    return cuda_kernel_factory(
        kernel_str, (float_dtype, int_dtype), "count_classes"
    )
def count_features_dense_kernel(float_dtype, int_dtype):
    """
    Build the kernel that accumulates per-class feature sums from a dense
    input.

    Each thread handles one ``(row, col)`` element. The value is read with
    row-major or column-major indexing depending on ``rowMajor`` and is
    atomically added to ``out[col * n_classes + label]``, where ``label``
    is ``labels[row]``.

    Kernel flags:

    - ``categorical``: the value itself selects the output slot
      (``val * n_classes * n_cols + label * n_cols + col``) and the amount
      added becomes 1.
    - ``has_weights``: the amount is multiplied by ``weights[row]``.
    - ``square``: the amount is squared before being added.

    Amounts equal to zero are skipped (no atomic add is issued).

    Parameters
    ----------
    float_dtype : dtype
        Substituted for ``{0}`` in the kernel source (type of ``out``,
        ``in`` and ``weights``).
    int_dtype : dtype
        Substituted for ``{1}`` in the kernel source (type of ``labels``).

    Returns
    -------
    The object produced by ``cuda_kernel_factory``; elsewhere in this file
    it is launched as ``kernel(grid, block, args)``.
    """
    kernel_str = r"""
({0} *out,
{0} *in,
int n_rows,
int n_cols,
{1} *labels,
{0} *weights,
bool has_weights,
int n_classes,
bool square,
bool rowMajor,
bool categorical) {
int row = blockIdx.x * blockDim.x + threadIdx.x;
int col = blockIdx.y * blockDim.y + threadIdx.y;
if(row >= n_rows || col >= n_cols) return;
{0} val = !rowMajor ?
in[col * n_rows + row] : in[row * n_cols + col];
{1} label = labels[row];
unsigned out_idx = ((col * n_classes) + label);
if (categorical)
{
out_idx = (val * n_classes * n_cols) + (label * n_cols) + col;
val = 1;
}
if(has_weights)
val *= weights[row];
if(val == 0.0) return;
if(square) val *= val;
atomicAdd(out + out_idx, val);
}"""
    return cuda_kernel_factory(
        kernel_str, (float_dtype, int_dtype), "count_features_dense"
    )
def _convert_x_sparse(X):
    """
    Convert a sparse matrix to a ``cupyx.scipy.sparse.coo_matrix``.

    The input is first converted to COO with ``X.tocoo()``; its row,
    column and data arrays are then passed through ``cp.asarray`` (each
    keeping its own dtype) and reassembled with the original shape.

    Raises
    ------
    ValueError
        If the matrix dtype is neither float32 nor float64.
    """
    coo = X.tocoo()

    supported_dtypes = [cp.float32, cp.float64]
    if coo.dtype not in supported_dtypes:
        raise ValueError(
            "Only floating-point dtypes (float32 or "
            "float64) are supported for sparse inputs."
        )

    # Same conversion for every COO component, in row/col/data order.
    row_idx, col_idx, values = (
        cp.asarray(part, dtype=part.dtype)
        for part in (coo.row, coo.col, coo.data)
    )
    return cupyx.scipy.sparse.coo_matrix(
        (values, (row_idx, col_idx)), shape=coo.shape
    )
class _BaseNB(Base, ClassifierMixin):
    """
    Shared prediction logic for the Naive Bayes estimators in this file.

    The prediction methods rely on ``self._joint_log_likelihood(X)``, which
    is not defined here and must be supplied by subclasses, and on
    ``self.classes_``. ``_check_X`` may be overridden to validate inputs.
    """

    classes_ = CumlArrayDescriptor()
    class_count_ = CumlArrayDescriptor()
    feature_count_ = CumlArrayDescriptor()
    class_log_prior_ = CumlArrayDescriptor()
    feature_log_prob_ = CumlArrayDescriptor()

    def __init__(self, *, verbose=False, handle=None, output_type=None):
        super(_BaseNB, self).__init__(
            verbose=verbose, handle=handle, output_type=output_type
        )

    def _check_X(self, X):
        """To be overridden in subclasses with the actual checks."""
        return X

    @generate_docstring(
        X="dense_sparse",
        return_values={
            "name": "y_hat",
            "type": "dense",
            "description": "Predicted values",
            "shape": "(n_rows, 1)",
        },
    )
    def predict(self, X) -> CumlArray:
        """
        Perform classification on an array of test vectors X.
        """
        if has_scipy():
            from scipy.sparse import isspmatrix as scipy_sparse_isspmatrix
        else:
            from cuml.internals.import_utils import (
                dummy_function_always_false as scipy_sparse_isspmatrix,
            )

        # todo: use a sparse CumlArray style approach when ready
        # https://github.com/rapidsai/cuml/issues/2216
        if scipy_sparse_isspmatrix(X) or cupyx.scipy.sparse.isspmatrix(X):
            X = _convert_x_sparse(X)
            index = None
        else:
            X = input_to_cuml_array(
                X, order="K", check_dtype=[cp.float32, cp.float64, cp.int32]
            )
            # input_to_cuml_array returns a tuple-like result (it is
            # unpacked with ``*_`` elsewhere in this file); the index lives
            # on the converted array, not on that tuple.
            index = X.array.index
            # todo: improve index management for cupy based codebases
            X = X.array.to_output("cupy")

        X = self._check_X(X)
        jll = self._joint_log_likelihood(X)

        # Column of the largest joint log-likelihood = predicted class id.
        indices = cp.argmax(jll, axis=1).astype(self.classes_.dtype)

        y_hat = invert_labels(indices, classes=self.classes_)
        y_hat = CumlArray(data=y_hat, index=index)
        return y_hat

    @generate_docstring(
        X="dense_sparse",
        return_values={
            "name": "C",
            "type": "dense",
            "description": (
                "Returns the log-probability of the samples for each class in "
                "the model. The columns correspond to the classes in sorted "
                "order, as they appear in the attribute `classes_`."
            ),
            "shape": "(n_rows, n_classes)",
        },
    )
    def predict_log_proba(self, X) -> CumlArray:
        """
        Return log-probability estimates for the test vector X.
        """
        if has_scipy():
            from scipy.sparse import isspmatrix as scipy_sparse_isspmatrix
        else:
            from cuml.internals.import_utils import (
                dummy_function_always_false as scipy_sparse_isspmatrix,
            )

        # todo: use a sparse CumlArray style approach when ready
        # https://github.com/rapidsai/cuml/issues/2216
        if scipy_sparse_isspmatrix(X) or cupyx.scipy.sparse.isspmatrix(X):
            X = _convert_x_sparse(X)
            index = None
        else:
            X = input_to_cuml_array(
                X, order="K", check_dtype=[cp.float32, cp.float64, cp.int32]
            )
            # See predict(): the index is an attribute of the array.
            index = X.array.index
            # todo: improve index management for cupy based codebases
            X = X.array.to_output("cupy")

        X = self._check_X(X)
        jll = self._joint_log_likelihood(X)

        # normalize by P(X) = P(f_1, ..., f_n)
        # Compute log(sum(exp())), subtracting the row maximum inside the
        # exponential to prevent inf. keepdims=True keeps every term as a
        # column vector so it broadcasts against jll directly.
        a_max = cp.amax(jll, axis=1, keepdims=True)
        exp = cp.exp(jll - a_max)
        log_prob_x = a_max + cp.log(cp.sum(exp, axis=1, keepdims=True))

        result = jll - log_prob_x
        result = CumlArray(data=result, index=index)
        return result

    @generate_docstring(
        X="dense_sparse",
        return_values={
            "name": "C",
            "type": "dense",
            "description": (
                "Returns the probability of the samples for each class in the "
                "model. The columns correspond to the classes in sorted order,"
                " as they appear in the attribute `classes_`."
            ),
            "shape": "(n_rows, n_classes)",
        },
    )
    def predict_proba(self, X) -> CumlArray:
        """
        Return probability estimates for the test vector X.
        """
        result = cp.exp(self.predict_log_proba(X))
        return result
class GaussianNB(_BaseNB):
    """
    Gaussian Naive Bayes (GaussianNB)

    Can perform online updates to model parameters via :meth:`partial_fit`.
    For details on algorithm used to update feature means and variance online,
    see Stanford CS tech report STAN-CS-79-773 by Chan, Golub, and LeVeque:

    http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf

    Parameters
    ----------
    priors : array-like of shape (n_classes,)
        Prior probabilities of the classes. If specified the priors are not
        adjusted according to the data.
    var_smoothing : float, default=1e-9
        Portion of the largest variance of all features that is added to
        variances for calculation stability.
    output_type : {'input', 'array', 'dataframe', 'series', 'df_obj', \
        'numba', 'cupy', 'numpy', 'cudf', 'pandas'}, default=None
        Return results and set estimator attributes to the indicated output
        type. If None, the output type set at the module level
        (`cuml.global_settings.output_type`) will be used. See
        :ref:`output-data-type-configuration` for more info.
    handle : cuml.Handle
        Specifies the cuml.handle that holds internal CUDA state for
        computations in this model. Most importantly, this specifies the
        CUDA stream that will be used for the model's computations, so
        users can run different models concurrently in different streams
        by creating handles in several streams.
        If it is None, a new one is created.
    verbose : int or boolean, default=False
        Sets logging level. It must be one of `cuml.common.logger.level_*`.
        See :ref:`verbosity-levels` for more info.

    Examples
    --------

    .. code-block:: python

        >>> import cupy as cp
        >>> X = cp.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1],
        ...               [3, 2]], cp.float32)
        >>> Y = cp.array([1, 1, 1, 2, 2, 2], cp.float32)
        >>> from cuml.naive_bayes import GaussianNB
        >>> clf = GaussianNB()
        >>> clf.fit(X, Y)
        GaussianNB()
        >>> print(clf.predict(cp.array([[-0.8, -1]], cp.float32)))
        [1]
        >>> clf_pf = GaussianNB()
        >>> clf_pf.partial_fit(X, Y, cp.unique(Y))
        GaussianNB()
        >>> print(clf_pf.predict(cp.array([[-0.8, -1]], cp.float32)))
        [1]
    """

    def __init__(
        self,
        *,
        priors=None,
        var_smoothing=1e-9,
        output_type=None,
        handle=None,
        verbose=False,
    ):
        super(GaussianNB, self).__init__(
            handle=handle, verbose=verbose, output_type=output_type
        )
        self.priors = priors
        self.var_smoothing = var_smoothing
        self.fit_called_ = False
        self.classes_ = None

    def fit(self, X, y, sample_weight=None) -> "GaussianNB":
        """
        Fit Gaussian Naive Bayes classifier according to X, y

        Any state from a previous fit is discarded.

        Parameters
        ----------
        X : {array-like, cupy sparse matrix} of shape (n_samples, n_features)
            Training vectors, where n_samples is the number of samples and
            n_features is the number of features.
        y : array-like shape (n_samples) Target values.
        sample_weight : array-like of shape (n_samples)
            Weights applied to individual samples (1. for unweighted).
            Currently sample weight is ignored.

        Returns
        -------
        self : object
        """
        return self._partial_fit(
            X,
            y,
            _classes=cp.unique(y),
            _refit=True,
            sample_weight=sample_weight,
        )

    @nvtx_annotate(
        message="naive_bayes.GaussianNB._partial_fit", domain="cuml_python"
    )
    def _partial_fit(
        self,
        X,
        y,
        _classes=None,
        _refit=False,
        sample_weight=None,
        convert_dtype=True,
    ) -> "GaussianNB":
        """
        Shared implementation of :meth:`fit` and :meth:`partial_fit`.

        Parameters
        ----------
        X : dense array-like or sparse matrix, shape (n_samples, n_features)
        y : array-like of shape (n_samples)
        _classes : array-like or None
            All class labels; required on the first call.
        _refit : bool
            When True the model is re-initialised before this batch is
            absorbed, as if no fit had happened before.
        sample_weight : ignored
            Accepted for API compatibility; not forwarded to the update.
        convert_dtype : bool
            When True, ``y`` and ``_classes`` are converted to the label
            dtype chosen from the dtype of ``X``.

        Raises
        ------
        ValueError
            If no classes are known yet and none are passed, if the priors
            are inconsistent, or if ``y`` holds labels outside the known
            classes on a subsequent call.
        """
        if has_scipy():
            from scipy.sparse import isspmatrix as scipy_sparse_isspmatrix
        else:
            from cuml.internals.import_utils import (
                dummy_function_always_false as scipy_sparse_isspmatrix,
            )

        if getattr(self, "classes_") is None and _classes is None:
            raise ValueError(
                "classes must be passed on the first call " "to partial_fit."
            )

        if scipy_sparse_isspmatrix(X) or cupyx.scipy.sparse.isspmatrix(X):
            X = _convert_x_sparse(X)
        else:
            X = input_to_cupy_array(
                X, order="K", check_dtype=[cp.float32, cp.float64, cp.int32]
            ).array

        # 32-bit features pair with int32 labels, anything else with int64.
        expected_y_dtype = (
            cp.int32 if X.dtype in [cp.float32, cp.int32] else cp.int64
        )
        y = input_to_cupy_array(
            y,
            convert_to_dtype=(expected_y_dtype if convert_dtype else False),
            check_dtype=expected_y_dtype,
        ).array

        if _classes is not None:
            _classes, *_ = input_to_cuml_array(
                _classes,
                order="K",
                convert_to_dtype=(
                    expected_y_dtype if convert_dtype else False
                ),
            )

        Y, label_classes = make_monotonic(y, classes=_classes, copy=True)

        if _refit:
            self.classes_ = None
            # A refit must go through the first-call initialisation below.
            # Without this reset a second fit() would take the incremental
            # branch with classes_ set to None.
            self.fit_called_ = False

        def var_sparse(X, axis=0):
            # Compute the variance on dense and sparse matrices
            return ((X - X.mean(axis=axis)) ** 2).mean(axis=axis)

        self.epsilon_ = self.var_smoothing * var_sparse(X).max()

        if not self.fit_called_:
            self.fit_called_ = True

            # Original labels are stored on the instance
            if _classes is not None:
                check_labels(Y, _classes.to_output("cupy"))
                self.classes_ = _classes
            else:
                self.classes_ = label_classes

            n_features = X.shape[1]
            n_classes = len(self.classes_)

            self.n_classes_ = n_classes
            self.n_features_ = n_features

            self.theta_ = cp.zeros((n_classes, n_features))
            self.sigma_ = cp.zeros((n_classes, n_features))

            self.class_count_ = cp.zeros(n_classes, dtype=X.dtype)

            if self.priors is not None:
                if len(self.priors) != n_classes:
                    raise ValueError(
                        "Number of priors must match number of" " classes."
                    )
                if not cp.isclose(self.priors.sum(), 1):
                    raise ValueError("The sum of the priors should be 1.")
                if (self.priors < 0).any():
                    raise ValueError("Priors must be non-negative.")
                self.class_prior, *_ = input_to_cupy_array(
                    self.priors, check_dtype=[cp.float32, cp.float64]
                )
        else:
            # Remove the smoothing term added at the end of the previous
            # call before merging in the new batch.
            self.sigma_[:, :] -= self.epsilon_

        unique_y = cp.unique(y)
        unique_y_in_classes = cp.in1d(unique_y, cp.array(self.classes_))

        if not cp.all(unique_y_in_classes):
            raise ValueError(
                "The target label(s) %s in y do not exist "
                "in the initial classes %s"
                % (unique_y[~unique_y_in_classes], self.classes_)
            )

        self.theta_, self.sigma_ = self._update_mean_variance(X, Y)

        self.sigma_[:, :] += self.epsilon_

        if self.priors is None:
            self.class_prior = self.class_count_ / self.class_count_.sum()

        return self

    def partial_fit(
        self, X, y, classes=None, sample_weight=None
    ) -> "GaussianNB":
        """
        Incremental fit on a batch of samples.

        This method is expected to be called several times consecutively on
        different chunks of a dataset so as to implement out-of-core or online
        learning.

        This is especially useful when the whole dataset is too big to fit in
        memory at once.

        This method has some performance overhead hence it is better to call
        partial_fit on chunks of data that are as large as possible (as long
        as fitting in the memory budget) to hide the overhead.

        Parameters
        ----------
        X : {array-like, cupy sparse matrix} of shape (n_samples, n_features)
            Training vectors, where n_samples is the number of samples and
            n_features is the number of features. A sparse matrix in COO
            format is preferred, other formats will go through a conversion
            to COO.
        y : array-like of shape (n_samples) Target values.
        classes : array-like of shape (n_classes)
            List of all the classes that can possibly appear in the y
            vector. Must be provided at the first call to partial_fit,
            can be omitted in subsequent calls.
        sample_weight : array-like of shape (n_samples)
            Weights applied to individual samples (1. for
            unweighted). Currently sample weight is ignored.

        Returns
        -------
        self : object
        """
        return self._partial_fit(
            X, y, classes, _refit=False, sample_weight=sample_weight
        )

    def _update_mean_variance(self, X, Y, sample_weight=None):
        """
        Merge the per-class mean and variance of a new batch into the
        stored ``theta_`` / ``sigma_``.

        Per-class sums and sums of squares of ``X`` are accumulated with
        the counting kernels, turned into batch means and variances, and
        then combined with the stored statistics weighted by the class
        counts seen so far. ``self.class_count_`` is incremented by the
        batch class counts as a side effect.

        Parameters
        ----------
        X : cupy.ndarray or cupyx sparse matrix, shape (n_rows, n_cols)
        Y : array of monotonic class labels, one per row of ``X``
        sample_weight : array or None
            Per-sample weights; an empty array (the default) means
            unweighted.

        Returns
        -------
        (mean, variance) : tuple of arrays of shape (n_classes, n_features)
        """
        if sample_weight is None:
            sample_weight = cp.zeros(0)

        labels_dtype = self.classes_.dtype

        mu = self.theta_
        var = self.sigma_

        # Captured before class_count_ is updated below: with no history
        # the batch statistics are returned as they are.
        early_return = self.class_count_.sum() == 0
        n_past = cp.expand_dims(self.class_count_, axis=1).copy()
        tpb = 32
        n_rows = X.shape[0]
        n_cols = X.shape[1]

        if X.shape[0] == 0:
            return mu, var

        # Make sure Y is cp array not CumlArray
        Y = cp.asarray(Y)

        new_mu = cp.zeros(
            (self.n_classes_, self.n_features_), order="F", dtype=X.dtype
        )
        new_var = cp.zeros(
            (self.n_classes_, self.n_features_), order="F", dtype=X.dtype
        )
        class_counts = cp.zeros(self.n_classes_, order="F", dtype=X.dtype)

        has_weights = sample_weight.shape[0] > 0

        # First pass accumulates plain sums (for the averages), second
        # pass accumulates sums of squares (for the variance).
        passes = ((new_mu, False), (new_var, True))

        if cupyx.scipy.sparse.isspmatrix(X):
            X = X.tocoo()

            count_features_coo = count_features_coo_kernel(
                X.dtype, labels_dtype
            )

            for out, square in passes:
                count_features_coo(
                    (math.ceil(X.nnz / tpb),),
                    (tpb,),
                    (
                        out,
                        X.row,
                        X.col,
                        X.data,
                        X.nnz,
                        n_rows,
                        n_cols,
                        Y,
                        sample_weight,
                        has_weights,
                        self.n_classes_,
                        square,
                    ),
                )
        else:
            count_features_dense = count_features_dense_kernel(
                X.dtype, labels_dtype
            )

            for out, square in passes:
                count_features_dense(
                    (math.ceil(n_rows / tpb), math.ceil(n_cols / tpb), 1),
                    (tpb, tpb, 1),
                    (
                        out,
                        X,
                        n_rows,
                        n_cols,
                        Y,
                        sample_weight,
                        has_weights,
                        self.n_classes_,
                        square,
                        X.flags["C_CONTIGUOUS"],
                        False,
                    ),
                )

        count_classes = count_classes_kernel(X.dtype, labels_dtype)
        count_classes(
            (math.ceil(n_rows / tpb),), (tpb,), (class_counts, n_rows, Y)
        )

        self.class_count_ += class_counts

        # Avoid any division by zero
        class_counts = cp.expand_dims(class_counts, axis=1)
        class_counts += cp.finfo(X.dtype).eps

        new_mu /= class_counts

        # Construct variance from sum squares
        new_var = (new_var / class_counts) - new_mu**2

        if early_return:
            return new_mu, new_var

        # Compute (potentially weighted) mean and variance of new datapoints
        if has_weights:
            n_new = float(sample_weight.sum())
        else:
            n_new = class_counts

        n_total = n_past + n_new
        total_mu = (new_mu * n_new + mu * n_past) / n_total

        # Combine the sums of squared deviations of the old and new data,
        # plus a correction for the shift between their means.
        old_ssd = var * n_past
        new_ssd = n_new * new_var

        ssd_sum = old_ssd + new_ssd
        combined_feature_counts = n_new * n_past / n_total
        mean_adj = (mu - new_mu) ** 2

        total_ssd = ssd_sum + combined_feature_counts * mean_adj

        total_var = total_ssd / n_total
        return total_mu, total_var

    def _joint_log_likelihood(self, X):
        """
        Return, for every row of ``X``, the log prior plus the Gaussian
        log-likelihood under each class, as an array with one column per
        class.
        """
        joint_log_likelihood = []

        for i in range(len(self.classes_)):
            jointi = cp.log(self.class_prior[i])

            # Normalisation term of the per-feature Gaussians.
            n_ij = -0.5 * cp.sum(cp.log(2.0 * cp.pi * self.sigma_[i, :]))

            centered = (X - self.theta_[i, :]) ** 2
            zvals = centered / self.sigma_[i, :]
            summed = cp.sum(zvals, axis=1)

            n_ij = -(0.5 * summed) + n_ij
            joint_log_likelihood.append(jointi + n_ij)

        return cp.array(joint_log_likelihood).T

    def get_param_names(self):
        """Return the base parameter names plus those added by this class."""
        return super().get_param_names() + ["priors", "var_smoothing"]
class _BaseDiscreteNB(_BaseNB):
def __init__(
    self,
    *,
    alpha=1.0,
    fit_prior=True,
    class_prior=None,
    verbose=False,
    handle=None,
    output_type=None,
):
    """
    Store the hyper-parameters and reset the fitted state.

    Parameters
    ----------
    alpha : float, default=1.0
        Smoothing parameter; must be >= 0.
    fit_prior : bool, default=True
        Stored as ``self.fit_prior``.
    class_prior : array-like or None
        When given it is converted with ``input_to_cuml_array`` and stored
        as ``self.class_prior``.
    verbose, handle, output_type
        Forwarded to the parent constructor.

    Raises
    ------
    ValueError
        If ``alpha`` is negative.
    """
    super(_BaseDiscreteNB, self).__init__(
        verbose=verbose, handle=handle, output_type=output_type
    )

    converted_prior = None
    if class_prior is not None:
        converted_prior, *_ = input_to_cuml_array(class_prior)
    self.class_prior = converted_prior

    if alpha < 0:
        raise ValueError("Smoothing parameter alpha should be >= 0.")

    self.alpha = alpha
    self.fit_prior = fit_prior

    # Nothing has been fitted yet.
    self.fit_called_ = False
    self.n_classes_ = 0
    self.n_features_ = None

    # Needed until Base no longer assumed cumlHandle
    self.handle = None
def _check_X_y(self, X, y):
    """
    Return ``X`` and ``y`` unchanged.

    ``_partial_fit`` passes its converted inputs through this method
    before counting.
    """
    return X, y
def _update_class_log_prior(self, class_prior=None):
    """
    Recompute ``self.class_log_prior_``.

    - With an explicit ``class_prior``, its logarithm is used.
    - Otherwise, when ``self.fit_prior`` is set, the log of the empirical
      class frequencies (from ``self.class_count_``) is used.
    - Otherwise every class gets ``-log(n_classes)``.

    Raises
    ------
    ValueError
        If ``class_prior`` does not have one entry per class.
    """
    if class_prior is not None:
        if class_prior.shape[0] != self.n_classes_:
            raise ValueError("Number of classes must match number of priors")
        self.class_log_prior_ = cp.log(class_prior)
        return

    if not self.fit_prior:
        # Uniform prior over the known classes.
        uniform_log_prior = -math.log(self.n_classes_)
        self.class_log_prior_ = cp.full(self.n_classes_, uniform_log_prior)
        return

    # Empirical prior: log(count_c) - log(total count).
    log_counts = cp.log(self.class_count_)
    log_total = cp.log(self.class_count_.sum())
    self.class_log_prior_ = log_counts - log_total
def partial_fit(
    self, X, y, classes=None, sample_weight=None
) -> "_BaseDiscreteNB":
    """
    Incremental fit on a batch of samples.

    This method is expected to be called several times consecutively on
    different chunks of a dataset so as to implement out-of-core or online
    learning.

    This is especially useful when the whole dataset is too big to fit in
    memory at once.

    This method has some performance overhead hence it is better to call
    partial_fit on chunks of data that are as large as possible (as long
    as fitting in the memory budget) to hide the overhead.

    Parameters
    ----------
    X : {array-like, cupy sparse matrix} of shape (n_samples, n_features)
        Training vectors, where n_samples is the number of samples and
        n_features is the number of features
    y : array-like of shape (n_samples) Target values.
    classes : array-like of shape (n_classes)
        List of all the classes that can possibly appear in the y
        vector. Must be provided at the first call to partial_fit,
        can be omitted in subsequent calls.
    sample_weight : array-like of shape (n_samples)
        Weights applied to individual samples (1. for
        unweighted). Currently sample weight is ignored.

    Returns
    -------
    self : object
    """
    # Keyword arguments are used because _partial_fit orders its
    # parameters differently (sample_weight before _classes).
    return self._partial_fit(
        X, y, sample_weight=sample_weight, _classes=classes
    )
@nvtx_annotate(
    message="naive_bayes._BaseDiscreteNB._partial_fit",
    domain="cuml_python",
)
def _partial_fit(
    self, X, y, sample_weight=None, _classes=None, convert_dtype=True
) -> "_BaseDiscreteNB":
    """
    Absorb one batch into the running counts and refresh the log
    probabilities.

    On the first call (``self.fit_called_`` is False) the classes, the
    number of features and the counters are initialised; on later calls
    the labels are only checked against the stored classes.

    Parameters
    ----------
    X : dense array-like or sparse matrix, shape (n_samples, n_features)
    y : array-like of shape (n_samples)
    sample_weight : not used by this method
    _classes : array-like or None
        Class labels to use on the first call; when None the classes
        returned by ``make_monotonic`` are used.
    convert_dtype : bool
        When True, ``y`` and ``_classes`` are converted to the label dtype
        chosen from the dtype of ``X``.

    Returns
    -------
    self
    """
    if has_scipy():
        from scipy.sparse import isspmatrix as scipy_sparse_isspmatrix
    else:
        from cuml.internals.import_utils import (
            dummy_function_always_false as scipy_sparse_isspmatrix,
        )

    # TODO: use SparseCumlArray
    is_sparse_input = scipy_sparse_isspmatrix(
        X
    ) or cupyx.scipy.sparse.isspmatrix(X)
    if is_sparse_input:
        X = _convert_x_sparse(X)
    else:
        dense = input_to_cupy_array(
            X, order="K", check_dtype=[cp.float32, cp.float64, cp.int32]
        )
        X = dense.array

    # 32-bit features pair with int32 labels, anything else with int64.
    if X.dtype in [cp.float32, cp.int32]:
        expected_y_dtype = cp.int32
    else:
        expected_y_dtype = cp.int64
    label_conversion = expected_y_dtype if convert_dtype else False

    y = input_to_cupy_array(
        y,
        convert_to_dtype=label_conversion,
        check_dtype=expected_y_dtype,
    ).array

    if _classes is not None:
        _classes, *_ = input_to_cuml_array(
            _classes,
            order="K",
            convert_to_dtype=label_conversion,
        )

    Y, label_classes = make_monotonic(y, classes=_classes, copy=True)

    X, Y = self._check_X_y(X, Y)

    if self.fit_called_:
        check_labels(Y, self.classes_)
    else:
        # First batch: fix the classes and allocate the counters.
        self.fit_called_ = True
        if _classes is None:
            self.classes_ = label_classes
        else:
            check_labels(Y, _classes.to_output("cupy"))
            self.classes_ = _classes

        self.n_classes_ = self.classes_.shape[0]
        self.n_features_ = X.shape[1]
        self._init_counters(self.n_classes_, self.n_features_, X.dtype)

    if cupyx.scipy.sparse.isspmatrix(X):
        # X is assumed to be a COO here
        self._count_sparse(X.row, X.col, X.data, X.shape, Y, self.classes_)
    else:
        self._count(X, Y, self.classes_)

    self._update_feature_log_prob(self.alpha)
    self._update_class_log_prior(class_prior=self.class_prior)

    return self
def fit(self, X, y, sample_weight=None) -> "_BaseDiscreteNB":
    """
    Fit Naive Bayes classifier according to X, y

    Resets ``fit_called_`` so the next ``partial_fit`` call is treated as
    the first one, then delegates to :meth:`partial_fit`.

    Parameters
    ----------
    X : {array-like, cupy sparse matrix} of shape (n_samples, n_features)
        Training vectors, where n_samples is the number of samples and
        n_features is the number of features.
    y : array-like shape (n_samples) Target values.
    sample_weight : array-like of shape (n_samples)
        Weights applied to individual samples (1. for unweighted).
        Currently sample weight is ignored.

    Returns
    -------
    self : object
    """
    self.fit_called_ = False
    # partial_fit's third positional parameter is ``classes``; pass the
    # weights by keyword so they are not mistaken for the class list.
    return self.partial_fit(X, y, sample_weight=sample_weight)
def _init_counters(self, n_effective_classes, n_features, dtype):
    """
    Allocate zero-filled accumulators for the class and feature counts.

    ``class_count_`` gets one entry per class and ``feature_count_`` one
    row per class and one column per feature; both use order "F" and the
    given ``dtype``.
    """
    zeros_options = {"order": "F", "dtype": dtype}
    self.class_count_ = cp.zeros(n_effective_classes, **zeros_options)
    self.feature_count_ = cp.zeros(
        (n_effective_classes, n_features), **zeros_options
    )
def update_log_probs(self):
    """
    Updates the log probabilities. This enables lazy update for
    applications like distributed Naive Bayes, so that the model
    can be updated incrementally without incurring this cost each
    time.

    Recomputes the feature log probabilities using ``self.alpha`` and the
    class log prior using ``self.class_prior``.
    """
    self._update_feature_log_prob(self.alpha)
    self._update_class_log_prior(class_prior=self.class_prior)
def _count(self, X, Y, classes):
    """
    Sum feature counts & class prior counts and add to current model.

    Parameters
    ----------
    X : cupy.ndarray or cupyx.scipy.sparse matrix of size
        (n_rows, n_features)
    Y : cupy.array of monotonic class labels
    classes : array of class labels; its length gives the number of
        classes and its dtype is the label dtype the kernels are built for

    Raises
    ------
    ValueError
        If ``X`` is not two-dimensional.
    """
    n_classes = classes.shape[0]
    # An empty weight array means "unweighted" to the kernel.
    sample_weight = cp.zeros(0)

    if X.ndim != 2:
        raise ValueError("Input samples should be a 2D array")

    labels_dtype = classes.dtype

    if Y.dtype != labels_dtype:
        warnings.warn(
            "Y dtype does not match classes_ dtype. Y will be "
            "converted, which will increase memory consumption"
        )

    # Make sure Y is a cupy array, not CumlArray, and that it really has
    # the label dtype the kernels below are compiled for. The warning
    # above announces this conversion.
    Y = cp.asarray(Y, dtype=labels_dtype)

    counts = cp.zeros(
        (n_classes, self.n_features_), order="F", dtype=X.dtype
    )

    class_c = cp.zeros(n_classes, order="F", dtype=X.dtype)

    n_rows = X.shape[0]
    n_cols = X.shape[1]

    tpb = 32

    count_features_dense = count_features_dense_kernel(
        X.dtype, labels_dtype
    )
    count_features_dense(
        (math.ceil(n_rows / tpb), math.ceil(n_cols / tpb), 1),
        (tpb, tpb, 1),
        (
            counts,
            X,
            n_rows,
            n_cols,
            Y,
            sample_weight,
            sample_weight.shape[0] > 0,
            n_classes,
            False,
            X.flags["C_CONTIGUOUS"],
            False,
        ),
    )

    # The class-count kernel is one-dimensional, so a larger block is used.
    tpb = 256
    count_classes = count_classes_kernel(X.dtype, labels_dtype)
    count_classes((math.ceil(n_rows / tpb),), (tpb,), (class_c, n_rows, Y))

    self.feature_count_ += counts
    self.class_count_ += class_c
def _count_sparse(
self, x_coo_rows, x_coo_cols, x_coo_data, x_shape, Y, classes
):
"""
Sum feature counts & class prior counts and add to current model.
Parameters
----------
x_coo_rows : cupy.ndarray of size (nnz)
x_coo_cols : cupy.ndarray of size (nnz)
x_coo_data : cupy.ndarray of size (nnz)
Y : cupy.array of monotonic class labels