In [None]:
%load_ext lab_black

In [2]:
from solike.tests.test_gaussian import toy_data

datalist, cross_cov = toy_data()

In [3]:
from solike.gaussian import MultiGaussianData

multi = MultiGaussianData(datalist, cross_cov)

In [5]:
multi.name

'A + B + C'

In [7]:
import numpy as np
from scipy.linalg import cholesky, LinAlgError
from scipy.stats import multivariate_normal


class GaussianData(object):
    """Named multivariate gaussian data 
    
    For CMB PS data, x will typically be l, and y will be power spectrum.
    """

    def __init__(self, name, x, y, cov):

        self.name = str(name)

        if not (len(x) == len(y) and cov.shape == (len(x), len(x))):
            raise ValueError(
                f"Incompatible shapes! x={x.shape}, y={y.shape}, cov={cov.shape}"
            )

        self.x = x
        self.y = y
        self.cov = cov
        try:
            self.cholesky = cholesky(cov)
        except LinAlgError:
            raise ValueError("Covariance is not SPD!")
        self.norm = multivariate_normal(self.y, cov=self.cov)

    def __len__(self):
        return len(self.x)


class MultiGaussianData(object):
    """
    
    Parameters
    ----------
    data_list : list
        List of Data objects
        
    cross_covs : dictionary
        Cross-covariances, keyed by (name1, name2) tuples.
    """

    def __init__(self, data_list, cross_covs=None):

        if cross_covs is None:
            cross_covs = {}

        # Ensure all cross-covs are proper shape, and fill with zeros if not present
        for d1 in data_list:
            for d2 in data_list:
                key = (d1.name, d2.name)

                if d1 == d2:
                    cross_covs[key] = d1.cov

                rev_key = (d2.name, d1.name)
                if key in cross_covs:
                    cov = cross_covs[key]
                    if not cov.shape == (len(d1), len(d2)):
                        raise ValueError(
                            f"Cross-covariance (for {d1.name} x {d2.name}) has wrong shape: {cov.shape}!"
                        )
                elif rev_key in cross_covs:
                    cross_covs[key] = cross_covs[rev_key].T
                else:
                    cross_covs[key] = np.zeros((len(d1), len(d2)))

        self.data_list = data_list
        self.lengths = [len(d) for d in data_list]
        self.names = [d.name for d in data_list]
        self.cross_covs = cross_covs

        self._data = None

    @property
    def data(self):
        if self._data is None:
            self._assemble_data()
        return self._data

    def _index_range(self, name):
        if name not in self.names:
            raise ValueError(f"{name} not in {self.names}!")

        i0 = 0
        for n, length in zip(self.names, self.lengths):
            if n == name:
                i1 = i0 + length
                break
            i0 += length
        return i0, i1

    def _slice(self, *names):
        if isinstance(names, str):
            names = [names]

        return np.s_[tuple(slice(*multi._index_range(n)) for n in names)]

    def _assemble_data(self):
        x = np.concatenate([d.x for d in self.data_list])
        y = np.concatenate([d.y for d in self.data_list])

        N = sum([len(d) for d in self.data_list])

        cov = np.zeros((N, N))
        for n1 in self.names:
            for n2 in self.names:
                cov[self._slice(n1, n2)] = self.cross_covs[(n1, n2)]

        self._data = Data(" + ".join(self.names), x, y, cov)

In [8]:
from sklearn.datasets import make_spd_matrix

name1 = "A"
n1 = 10
x1 = np.arange(n1)
y1 = np.random.random(n1)

name2 = "B"
n2 = 20
x2 = np.arange(n2)
y2 = np.random.random(n2)

name3 = "C"
n3 = 30
x3 = np.arange(n3)
y3 = np.random.random(n3)

# Generate arbitrary covariance matrix, partition into parts
full_cov = make_spd_matrix(n1 + n2 + n3, random_state=1234)
cov1 = full_cov[:n1, :n1]
cov2 = full_cov[n1 : n1 + n2, n1 : n1 + n2]
cov3 = full_cov[n1 + n2 :, n1 + n2 :]

data1 = GaussianData(name1, x1, y1, cov1)
data2 = GaussianData(name2, x2, y2, cov2)
data3 = GaussianData(name3, x3, y3, cov3)

cross_cov = {
    (name1, name2): full_cov[:n1, n1 : n1 + n2],
    (name1, name3): full_cov[:n1, n1 + n2 :],
    (name2, name3): full_cov[n1 : n1 + n2, n1 + n2 :],
}

In [9]:
multi = MultiGaussianData([data1, data2, data3], cross_cov)

assert (multi.cross_covs[(name1, name2)] == multi.cross_covs[(name2, name1)].T).all()
assert (multi.cross_covs[(name1, name3)] == multi.cross_covs[(name3, name1)].T).all()
assert (multi.cross_covs[(name2, name3)] == multi.cross_covs[(name3, name2)].T).all()

assert (multi.cross_covs[(name1, name1)] == data1.cov).all()
assert (multi.cross_covs[(name2, name2)] == data2.cov).all()
assert (multi.cross_covs[(name3, name3)] == data3.cov).all()

In [60]:
multi._slice('A', 'B')

(slice(0, 10, None), slice(10, 30, None))

In [51]:
multi._slice('B', 'C')

(slice(10, 30, None), slice(30, 60, None))

In [15]:
multi._index_range('A')

(0, 10)

In [16]:
multi._index_range('B')

(10, 30)

In [17]:
multi._index_range('C')

(30, 60)

In [30]:
np.s_[(slice(*multi._index_range('A')), slice(*multi._index_range('B')))]

(slice(0, 10, None), slice(10, 30, None))

In [36]:
np.s_[tuple(slice(*multi._index_range(n)) for n in ['A', 'B'])]

(slice(0, 10, None), slice(10, 30, None))

In [36]:
from scipy.stats import multivariate_normal

In [38]:
norm1 = multivariate_normal(y1, cov=cov1)

In [33]:
assert 
multi.cross_covs.keys()

dict_keys([('A', 'B'), ('B', 'A')])