-
Notifications
You must be signed in to change notification settings - Fork 56
/
_randomvariablelist.py
99 lines (78 loc) · 2.9 KB
/
_randomvariablelist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from typing import Union
import numpy as np
from probnum import randvars
try:
# functools.cached_property is only available in Python >=3.8
from functools import cached_property
except ImportError:
from cached_property import cached_property
class _RandomVariableList(list):
    """List of RandomVariables with convenient access to means, covariances, etc.

    Parameters
    ----------
    rv_list :
        :obj:`list` of :obj:`RandomVariable`

    Raises
    ------
    TypeError
        If ``rv_list`` is not a :obj:`list`, or if it is non-empty and its first
        element is not a :obj:`RandomVariable`.

    Notes
    -----
    The stacked statistics (``mean``, ``cov``, ...) are computed on first access
    and cached on the instance. The cache is not invalidated when the list is
    mutated afterwards (e.g. via ``append``), so treat instances as read-only
    once a statistic has been accessed.
    """

    def __init__(self, rv_list: list):
        if not isinstance(rv_list, list):
            raise TypeError("RandomVariableList expects a list.")

        # Only the first element is inspected, as a cheap proxy for checking all
        # elements. An empty list is accepted as is.
        if rv_list and not isinstance(rv_list[0], randvars.RandomVariable):
            raise TypeError(
                "RandomVariableList expects RandomVariable elements, but "
                + f"first element has type {type(rv_list[0])}."
            )

        super().__init__(rv_list)

    def __getitem__(
        self, idx
    ) -> Union["randvars.RandomVariable", "_RandomVariableList"]:
        """Return a single element, or a ``_RandomVariableList`` for list results."""
        result = super().__getitem__(idx)

        # ``list.__getitem__`` returns a plain list for slices; re-wrap it so the
        # convenience accessors remain available on the result.
        if isinstance(result, list):
            result = _RandomVariableList(result)
        return result

    def _stack(self, attribute: str) -> np.ndarray:
        """Stack ``attribute`` of every element along a new leading axis.

        Returns an empty one-dimensional array if the list is empty.
        """
        if not self:
            return np.array([])
        return np.stack([getattr(rv, attribute) for rv in self])

    @cached_property
    def mean(self) -> np.ndarray:
        """Stacked ``mean`` of all elements."""
        return self._stack("mean")

    @cached_property
    def cov(self) -> np.ndarray:
        """Stacked ``cov`` of all elements."""
        return self._stack("cov")

    @cached_property
    def var(self) -> np.ndarray:
        """Stacked ``var`` of all elements."""
        return self._stack("var")

    @cached_property
    def std(self) -> np.ndarray:
        """Stacked ``std`` of all elements."""
        return self._stack("std")

    @property
    def shape(self):
        """Number of elements followed by the shape of the first element's mean.

        An empty list has shape ``(0,)``, which matches the shape of the empty
        array returned by the stacked accessors in that case.
        """
        if not self:
            return (0,)
        first_mean = np.asarray(self[0].mean)
        return (len(self),) + first_mean.shape

    @cached_property
    def mode(self) -> np.ndarray:
        """Stacked ``mode`` of all elements."""
        return self._stack("mode")

    # For discrete random variables:

    @cached_property
    def support(self) -> np.ndarray:
        """Stacked ``support`` of all elements."""
        return self._stack("support")

    @cached_property
    def probabilities(self) -> np.ndarray:
        """Stacked ``probabilities`` of all elements."""
        return self._stack("probabilities")

    # Purely for lists of categorical random variables.

    def resample(self, rng: np.random.Generator) -> "_RandomVariableList":
        """Call ``resample(rng=rng)`` on every element and collect the results.

        Returns an empty ``_RandomVariableList`` if the list is empty.
        """
        if not self:
            return _RandomVariableList([])
        return _RandomVariableList([rv.resample(rng=rng) for rv in self])