-
Notifications
You must be signed in to change notification settings - Fork 44
/
covariance_aggregations.py
51 lines (34 loc) · 1.53 KB
/
covariance_aggregations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from pysparkling.sql.expressions.aggregate.aggregations import Aggregation
class CovarianceStatAggregation(Aggregation):
    """Shared base for two-column aggregations backed by a covariance counter.

    Every merged row contributes one pair of evaluated values (``column1``,
    ``column2``) to ``self.stat_helper``. Subclasses read the finished
    statistic off that helper in ``eval`` and provide the display name in
    ``__str__``; both raise ``NotImplementedError`` here.
    """

    def __init__(self, column1, column2):
        """Store both column expressions and create an empty counter.

        :param column1: first expression, evaluated per row in ``merge``
        :param column2: second expression, evaluated per row in ``merge``
        """
        # Top level import would cause cyclic dependencies
        # pylint: disable=import-outside-toplevel
        from pysparkling.stat_counter import CovarianceCounter

        super(CovarianceStatAggregation, self).__init__(column1, column2)
        self.column1 = column1
        self.column2 = column2
        self.stat_helper = CovarianceCounter(method="pearson")

    def merge(self, row, schema):
        """Evaluate both columns on ``row`` and add the pair to the counter.

        :param row: object exposing ``eval(column, schema)``
        :param schema: passed through unchanged to ``row.eval``
        """
        self.stat_helper.add(
            row.eval(self.column1, schema),
            row.eval(self.column2, schema),
        )

    def mergeStats(self, other, schema):
        """Fold the statistics accumulated by ``other`` into this aggregation.

        :param other: a sibling aggregation carrying a ``stat_helper``
            attribute, or a bare counter object
        :param schema: unused here; kept so the signature stays unchanged
        """
        # The accumulated state lives on ``stat_helper`` (see ``__init__`` and
        # ``merge``), so hand the counter the other side's counter rather than
        # the aggregation wrapping it. Objects without a ``stat_helper``
        # attribute are passed through as before.
        # NOTE(review): assumes the counter's ``merge`` expects another
        # counter -- its definition is not visible from this file; confirm.
        self.stat_helper.merge(getattr(other, "stat_helper", other))

    def eval(self, row, schema):
        """Return the aggregated value; must be overridden by subclasses."""
        raise NotImplementedError

    def __str__(self):
        """Return the display name; must be overridden by subclasses."""
        raise NotImplementedError
class Corr(CovarianceStatAggregation):
    """Aggregation displayed as ``corr(column1, column2)``."""

    def eval(self, row, schema):
        """Return the counter's ``pearson_correlation``; arguments are unused."""
        counter = self.stat_helper
        return counter.pearson_correlation

    def __str__(self):
        """Render the aggregation name together with its two columns."""
        columns = (self.column1, self.column2)
        return "corr({0}, {1})".format(*columns)
class CovarSamp(CovarianceStatAggregation):
    """Aggregation displayed as ``covar_samp(column1, column2)``."""

    def eval(self, row, schema):
        """Return the counter's ``covar_samp``; arguments are unused."""
        counter = self.stat_helper
        return counter.covar_samp

    def __str__(self):
        """Render the aggregation name together with its two columns."""
        columns = (self.column1, self.column2)
        return "covar_samp({0}, {1})".format(*columns)
class CovarPop(CovarianceStatAggregation):
    """Aggregation displayed as ``covar_pop(column1, column2)``."""

    def eval(self, row, schema):
        """Return the counter's ``covar_pop``; arguments are unused."""
        counter = self.stat_helper
        return counter.covar_pop

    def __str__(self):
        """Render the aggregation name together with its two columns."""
        columns = (self.column1, self.column2)
        return "covar_pop({0}, {1})".format(*columns)
# Public API: only the concrete aggregations are exported; the shared base
# class is left out of the list.
__all__ = [
    "Corr",
    "CovarSamp",
    "CovarPop",
]