In [1]:
import os,sys,inspect
# Make the parent of the working directory importable so that the local
# genomicsurveillance checkout is picked up.
# The working directory is used directly: inspect.getfile() on a notebook cell
# frame returns a synthetic filename, and where that name resolves to depends
# on the IPython / ipykernel version (it can point into a temp directory).
currentdir = os.getcwd()
parentdir = os.path.dirname(currentdir)
# Guard so that re-running this cell does not keep prepending the same path.
if parentdir not in sys.path:
    sys.path.insert(0,parentdir)
from functools import lru_cache
import genomicsurveillance as gs

In [2]:
from genomicsurveillance.data import get_geodata, get_meta_data
from genomicsurveillance.utils import create_spline_basis, create_dates_df
from genomicsurveillance.gov_uk import get_specimen
from genomicsurveillance.handler import SVIModel
from genomicsurveillance.models import Sites
from genomicsurveillance.distributions import *
import numpy as np
import numpyro as npy

In [3]:
from genomicsurveillance.handler import SVIHandler

In [4]:
import pandas as pd

## Get metadata

In [5]:
# Load the metadata table; later cells read its ctry19nm, lad19cd and pop18 columns.
uk = get_meta_data()

In [6]:
# Keep only the metadata rows whose country name is England.
eng = uk.loc[uk["ctry19nm"] == "England"]

## Get cases

In [7]:
# Case counts are read from a local CSV snapshot (its first column becomes the
# index); the get_specimen() call is kept commented out as the alternative source.
# cases = get_specimen() 
cases = pd.read_csv('specimen-20210304.csv', index_col=0)

In [8]:
# Align the case table to the rows of `eng` (left join on the LAD code, so
# areas without case data become NaN rows), then keep the column window
# [-184, -4) -- 180 columns when the merged frame has at least 184 -- as a
# plain array.
cases = (
    eng.merge(cases, how='left', left_on='lad19cd', right_index=True)
    .iloc[:, -184:-4]
    .values
)

## Lineage data

In [9]:
# Reshape the wide table (one "<location>_<lineage>" column per series) into
# long form and split each column label at its first underscore.
lineages = (
    pd.read_csv('latest.txt', index_col=0)
    .melt("WeekEndDate", var_name="loc", value_name="samples")
    .assign(location=lambda df: df["loc"].apply(lambda x: x.split("_")[0]))
    .assign(lineage=lambda df: df["loc"].apply(lambda x: "_".join(x.split("_")[1:])))
    # The axis is given by keyword: the positional form drop("loc", 1) is
    # deprecated and no longer accepted by pandas 2.x.
    .drop(columns="loc")
)

In [10]:
# Distinct WeekEndDate values, in order of first appearance.
lineage_dates = lineages["WeekEndDate"].unique().tolist()

In [11]:
# Distinct lineage labels, in order of first appearance.
lineages_types = lineages["lineage"].unique().tolist()

In [12]:
# One (location x WeekEndDate) count matrix per lineage, right-joined onto
# `eng` so every matrix has the same rows, restricted to the dates from the
# third one onwards and stacked along a new trailing lineage axis.
lineages = np.stack(
    [
        lineages.loc[lineages["lineage"] == lineage_name]
        .pivot(index="location", columns="WeekEndDate", values="samples")
        .merge(eng, how="right", left_index=True, right_on="lad19cd")
        .loc[:, lineage_dates[2:]]
        .values
        for lineage_name in lineages_types
    ],
    axis=-1,
)

In [13]:
# Position of every lineage date on the case time axis.
# The date sequence is built once and reused; it was previously rebuilt for
# every single date (assumes create_dates_df is deterministic -- confirm).
all_dates = create_dates_df(cases.shape[1], '2020-09-01')
lineage_date_idx = np.array([all_dates.index(date) for date in lineage_dates[2:]])

In [14]:
# Number of leading lineage columns kept separately; the columns after them
# are summed into a single trailing column when lineages_red is built.
max_lineages = 10

In [15]:
# Keep the first `max_lineages` lineage columns and append one column holding
# the sum of all remaining ones.
lineages_red = np.concatenate(
    [
        lineages[..., :max_lineages],
        lineages[..., max_lineages:].sum(axis=-1, keepdims=True),
    ],
    axis=-1,
)

In [16]:
def isnot_nan_row(array: np.ndarray) -> np.ndarray:
    """Return the indices along axis 0 whose entries contain no NaN.

    Each row is summed over all trailing axes; a row is reported when that
    sum is not NaN.
    """
    trailing_axes = tuple(range(1, array.ndim))
    row_totals = array.sum(axis=trailing_axes)
    return np.nonzero(~np.isnan(row_totals))[0]
def is_nan_row(array: np.ndarray) -> np.ndarray:
    """Return the indices along axis 0 whose entries contain at least one NaN.

    Each row is summed over all trailing axes; a row is reported when that
    sum is NaN.
    """
    trailing_axes = tuple(range(1, array.ndim))
    row_totals = array.sum(axis=trailing_axes)
    return np.nonzero(np.isnan(row_totals))[0]

In [17]:



class IndependentMultiLineage(SVIModel):
    """Joint model of case counts and lineage counts per location.

    Total incidence per location is ``exp(beta_0 @ B[0].T)`` scaled by the sum
    of per-lineage terms ``exp(b * t + c)``.  The slope/offset pair ``(b, c)``
    of every lineage is drawn separately for each location; the prior mean of
    ``b`` is the shared ``mu_b`` and the prior mean of ``c`` is ``c_loc``.
    Case counts are scored with ``NegativeBinomial`` and lineage counts with
    ``MultinomialProbs``.

    NOTE(review): ``jnp``, ``dist``, ``index``, ``index_update``,
    ``NegativeBinomial`` and ``MultinomialProbs`` are not imported by name in
    this notebook; presumably they are provided by the wildcard import of
    ``genomicsurveillance.distributions`` -- confirm.
    """

    def __init__(
        self,
        cases,
        lineages,
        lineage_dates,
        population,
        basis=None,
        tau=6.5,
        init_scale=0.1,
        fit_a=False,
        fit_rho=False,
        beta_loc=-10.0,
        beta_scale=2.0,
        mu_a_scale=0.001,
        mu_b_scale=0.01,
        a_scale=0.001,
        b_scale=0.01,
        c_scale=5.0,
        rho_loc=np.log(25.0),
        rho_scale=1.0,
        multinomial_scale=1.0,
        time_scale=100.0,
        exclude=False,
        use_correlation=False,
        use_ar=False,
        *args,
        **kwargs
    ):
        """
        Store the data and the prior hyper-parameters of the model.

        :param cases: array indexed (location, time) of case counts.
        :param lineages: array indexed (location, lineage time, lineage); the
            last entry of the final axis is not given its own parameters
            (``num_lin = lineages.shape[-1] - 1``).
        :param lineage_dates: indices into the time axis of ``cases`` at which
            the lineage counts were observed.
        :param population: per-location population, multiplied into the rates.
        :param basis: optional basis; when ``None`` the second return value of
            ``create_spline_basis`` over ``arange(cases.shape[1])`` is used.
        :param tau: multiplier applied inside the exponent of ``R_1`` and ``sa``.
        :param init_scale: factor applied to the initial scales of the guide.
        :param fit_a: accepted but not used by this class.
        :param fit_rho: sample ``rho`` per location instead of fixing it at
            ``rho_loc``.
        :param beta_loc: prior mean of ``beta_0`` (also its initial guide mean).
        :param beta_scale: prior scale of ``beta_0``.
        :param mu_a_scale: accepted but not used by this class.
        :param mu_b_scale: prior scale of ``mu_b`` (before ``time_scale``).
        :param a_scale: accepted but not used by this class.
        :param b_scale: prior scale of the slopes ``b`` (before ``time_scale``).
        :param c_scale: prior scale of the offsets ``c``.
        :param rho_loc: prior mean of ``rho`` / fixed value of ``rho``.
        :param rho_scale: prior scale of ``rho``.
        :param multinomial_scale: passed as ``scale`` to ``MultinomialProbs``.
        :param time_scale: slopes are sampled multiplied by this factor and
            divided by it again before use.
        :param exclude: enables the "excluding missing lineages" branch of
            :meth:`model`.
        :param use_correlation: correlate neighbouring basis coefficients.
        :param use_ar: make the last three basis coefficients autoregressive.
        :param args: accepted but not forwarded.
        :param kwargs: forwarded to ``SVIModel.__init__``.
        """
        assert (cases.shape[0] == lineages.shape[0]), "cases and lineages must have the same number of locations"
        super().__init__(**kwargs)
        self.cases = cases
        self.lineages = lineages
        self.lineage_dates = lineage_dates
        self.population = population

        self.tau = tau
        self.init_scale = init_scale
        self.fit_rho = fit_rho
        # The multivariate-normal variant of the mu_b prior is switched off here.
        self.fit_mu_b_mvn = False

        self.beta_loc = beta_loc
        self.beta_scale = beta_scale

        self.mu_b_loc = 0.0
        self.mu_b_scale = mu_b_scale

        self.b_loc = 0.0
        self.b_scale = b_scale
        self.c_loc = -10.0
        self.c_scale = c_scale

        self.rho_loc = rho_loc
        self.rho_scale = rho_scale
        self.exclude = exclude

        self.use_correlation = use_correlation
        self.use_ar = use_ar

        self.multinomial_scale = multinomial_scale
        self.time_scale = time_scale

        if basis is None:
            # Default basis: one knot per ten time points (rounded up).
            _, self.B = create_spline_basis(np.arange(cases.shape[1]), num_knots=int(np.ceil(cases.shape[1] / 10)))
        else:
            self.B = basis

    @property
    @lru_cache()
    def _nan_idx(self):
        """Indices of the locations (axis 0) that have no NaN in cases or lineages.

        Despite the name, these are the rows that are kept.  The result is
        cached, so the data are assumed not to change after construction.
        """
        exclude = set(is_nan_row(self.lineages)) | set(is_nan_row(self.cases))
        # Iterate over locations (axis 0).  Iterating over the time axis, as
        # was done before, drops every location whose index is >= the number
        # of time points.
        return np.array(
            [i for i in range(self.cases.shape[0]) if i not in exclude], dtype=int
        )

    @property
    @lru_cache()
    def _missing_lineages(self):
        """0/1 mask (kept location x lineage) of lineages with a non-zero total count.

        Counts are summed over axis 1; the last lineage column is left out.
        A value of 1 means the lineage was observed at that location.
        """
        return (self.lineages[..., :-1].sum(1) != 0)[self._nan_idx].astype(int)

    def _pad(self, name, array, index, shape, func):
        """Expands the array to the full size shape

        ``array`` is written into a zero array of ``shape`` at ``index``; one
        extra slice produced by ``func((shape[0], 1, 1))`` is appended along
        the last axis and the result is registered as deterministic site
        ``name``.
        """
        expanded_array = jnp.zeros(shape)
        expanded_array = index_update(expanded_array, index, array)
        return npy.deterministic(
            name, jnp.concatenate([expanded_array, func((shape[0], 1, 1))], -1)
        )

    def aggregate(self):
        """Region-level aggregation of the model (not called by this class).

        NOTE(review): the body refers to names that are neither parameters nor
        locals (``num_countries``, ``num_basis``, ``N``, ``C``, ``f_lin``,
        ``lineage_red``, ``B``, ``b_1``, ``X``, ``specimen``, ``rho``) and to
        ``self.BETA_1``, which this class does not set.  Unless they exist as
        globals / on the base class it raises ``NameError``; left unchanged.
        """
        # Regression coefficients (num_location x num_basis)
        beta_1 = npy.sample(
            self.BETA_1,
            dist.MultivariateNormal(
                self.beta_loc,
                jnp.tile(self.beta_scale, (num_countries, num_basis, 1))
                * jnp.eye(num_basis).reshape(1, num_basis, num_basis),
            ),
        )
        N_regions = jnp.array([N[C == i].sum(0) for i in range(num_countries)])
        f_regions_lineage = jnp.stack(
            [jnp.nansum(f_lin[C == i], 0) for i in range(num_countries)], 0
        )
        f_regions = f_regions_lineage.sum(-1, keepdims=True)
        p_regions = npy.deterministic("p_regions", f_regions_lineage / f_regions)
        l_regions = jnp.stack(
            [jnp.nansum(lineage_red[C == i], 0) for i in range(num_countries)], 0
        )

        g_regions = jnp.exp(beta_1 @ B[0].T)
        print("g_regions", g_regions.shape)
        print("f_regions", f_regions.shape)
        mu_regions = g_regions * f_regions.squeeze()

        lamb_regions_lineage = npy.deterministic(
            "lamb_regions_1",
            N_regions.reshape(-1, 1, 1)
            * g_regions[..., jnp.newaxis]
            * f_regions_lineage,
        )
        lamb_regions = npy.deterministic(
            "lamb_regions", N_regions.reshape(-1, 1) * mu_regions
        )
        #         R_regions = npy.deterministic('R_regions', jnp.exp(((beta_1 @ B[1].T)[..., jnp.newaxis] + b_1) * self.tau))
        b_regions = jnp.stack(
            [jnp.nansum(b_1[C == i], 0) for i in range(num_countries)], 0
        )
        print("b_regions", b_regions.shape)
        R_regions = npy.deterministic(
            "R_regions",
            jnp.exp(((beta_1 @ B[1].T)[..., jnp.newaxis] + b_regions) * self.tau),
        )

        print("lamb_regions_lin", lamb_regions_lineage.shape)
        print("lamb_regions", lamb_regions.shape)
        npy.sample(
            "lineage_regions",
            MultinomialProbs(
                p_regions[:, X],
                total_count=l_regions.sum(-1),
                scale=self.multinomial_scale,
            ),
            obs=l_regions,
        )

        specimen_regions = jnp.stack(
            [jnp.nansum(specimen[C == i], 0) for i in range(num_countries)], 0
        )
        print("spe", specimen_regions.shape)
        npy.sample(
            "specimen_regions",
            NegativeBinomial(lamb_regions, jnp.exp(rho)),  # jnp.clip(, 1e-6, 1e6)
            obs=specimen_regions,
        )

    def model(self):
        """The model.

        Samples ``beta_0`` (per location x basis), ``mu_b`` (per lineage) and
        ``bc`` (per kept location, slopes and offsets of all lineages), derives
        lineage proportions and expected counts from them, and scores the
        observed cases and lineage counts of the locations in ``_nan_idx``.
        """

        num_ltla = self.cases.shape[0]
        num_time = self.cases.shape[1]
        num_lin = self.lineages.shape[-1] - 1
        num_basis = self.B.shape[-1]
        num_ltla_lin = self._nan_idx.shape[0]
        num_time_lin = self.lineages.shape[1]

        # Only plate_ltla and plate_lin are entered below; the remaining plates
        # are still constructed exactly as before.
        plate_time = npy.plate("time", num_time, dim=-1)
        plate_ltla = npy.plate("ltla", num_ltla, dim=-2)
        plate_basis = npy.plate("basis", num_basis, dim=-1)

        plate_lin = npy.plate("lin", num_lin, dim=-1)
        plate_lin_time = npy.plate("lin_time", num_time_lin, dim=-2)
        plate_lin_ltla = npy.plate("lin_ltla", num_ltla_lin, dim=-3)

        # dispersion parameter for lads
        if self.fit_rho:
            with plate_ltla:
                rho = npy.sample(Sites.RHO, dist.Normal(self.rho_loc, self.rho_scale))
        else:
            with plate_ltla:
                rho = self.rho_loc

        # Start from the identity so that the AR block below also works when
        # use_correlation is off (Σ0 used to be undefined in that case).
        Σ0 = jnp.eye(num_basis)
        if self.use_correlation:
            print("Correlation")
            for i in range(1, num_basis):
                Σ0 = index_update(
                    Σ0, index[i, i - 1], jnp.array(0.5)
                )  ## correlate neighbouring basis functions

        if self.use_ar:
            print("AR")
            Π0 = jnp.linalg.inv(Σ0)  ### THIS NEEDS TO GO HERE
            for i in range(num_basis - 3, num_basis):
                Π0 = index_update(Π0, index[i, i - 2 : i], jnp.array([1, -2]))
            Π0 = index_update(
                Π0,
                index[num_basis - 3, num_basis - 5 : num_basis - 3],
                0.5 * jnp.array([1, -2]),
            )  ## Make last 3 basis functions autoregressive
            Σ0 = jnp.linalg.inv(Π0)

        # Regression coefficients (num_location x num_basis)
        beta_0 = npy.sample(
            Sites.BETA_0,
            dist.MultivariateNormal(
                self.beta_loc,
                jnp.tile(self.beta_scale, (num_ltla, num_basis, 1))
                * jnp.eye(num_basis).reshape(1, num_basis, num_basis),
            ),
        )

        if self.use_correlation or self.use_ar:
            # NOTE(review): num_ltla is passed as the third positional argument
            # of dist.Uniform; confirm that this is what the distribution expects.
            β_0 = npy.sample("β_0", dist.Uniform(-30, 0, num_ltla))
            beta_0 = npy.deterministic("beta_0", beta_0 @ Σ0.T + β_0)

        # lineage priors
        with plate_lin:
            if self.fit_mu_b_mvn:
                mu_b = npy.sample(
                    Sites.MU_B,
                    dist.MultivariateNormal(
                        self.mu_b_loc, self.mu_b_scale * jnp.eye(num_lin)
                    ),
                )
            else:
                mu_b = npy.sample(
                    Sites.MU_B,
                    dist.Normal(
                        self.time_scale * self.mu_b_loc,
                        self.time_scale * self.mu_b_scale,
                    ),
                )

        # Prior mean of (b, c) per kept location: mu_b for the slopes (zeroed
        # where the lineage was never observed at that location) and c_loc for
        # the offsets.
        mu_bc = jnp.concatenate(
            [
                self._missing_lineages * jnp.repeat(mu_b.reshape(1, -1), num_ltla_lin, 0),
                jnp.repeat(
                    jnp.repeat(self.c_loc, num_lin).reshape(1, -1), num_ltla_lin, 0
                ),
            ],
            -1,
        )

        # Diagonal matrix per kept location, passed as the second positional
        # argument of MultivariateNormal.
        sd_bc = jnp.repeat(
            jnp.diag(
                jnp.concatenate(
                    [
                        self.time_scale * jnp.repeat(self.b_scale, num_lin),
                        jnp.repeat(self.c_scale, num_lin),
                    ]
                )
            ).reshape(1, 2 * num_lin, 2 * num_lin),
            num_ltla_lin,
            0,
        )
        bc = npy.sample("bc", dist.MultivariateNormal(mu_bc, sd_bc)).reshape(
            num_ltla_lin, 2, num_lin
        )

        # pad lineage parameters b, c to match the full size array
        b_1 = self._pad(
            Sites.B_1,
            bc[:, 0].reshape(num_ltla_lin, 1, -1) / self.time_scale,
            index[self._nan_idx, :, :],
            (num_ltla, 1, num_lin),
            jnp.zeros,
        )
        c_1 = self._pad(
            Sites.C_1,
            bc[:, 1].reshape(num_ltla_lin, 1, -1),
            index[self._nan_idx, :, :],
            (num_ltla, 1, num_lin),
            jnp.zeros,
        )

        # Lineage specific regression coefficients (num_ltla x num_basis x num_lin)
        f_lin = jnp.exp(b_1 * jnp.arange(num_time).reshape(1, -1, 1) + c_1)
        f = f_lin.sum(-1, keepdims=True)

        # Lineage proportions: each lineage term divided by their sum.
        p = npy.deterministic(Sites.P, f_lin / f)
        print("p", p.shape)

        if self.exclude:
            print("excluding missing lineages")
            # NOTE(review): `exclude` is not defined in this method, so this
            # branch raises NameError unless a global of that name exists.
            p = index_update(p, exclude, 0.0)

        g = jnp.exp(beta_0 @ self.B[0].T)
        mu = g * f.squeeze()
        lamb_1 = npy.deterministic(
            Sites.LAMB_1, self.population.reshape(-1, 1, 1) * g[..., jnp.newaxis] * f_lin
        )
        lamb = npy.deterministic(Sites.LAMB, self.population.reshape(-1, 1) * mu)
        G_r = npy.deterministic("G", beta_0 @ self.B[1].T)
        R_1 = npy.deterministic(
            Sites.R_1, jnp.exp(((beta_0 @ self.B[1].T)[..., jnp.newaxis] + b_1) * self.tau)
        )

        sa = npy.deterministic("sa", jnp.exp(mu_b / self.time_scale * self.tau))

        # Likelihoods are evaluated on the kept locations only.
        npy.sample(
            Sites.SPECIMEN,
            NegativeBinomial(lamb[self._nan_idx], jnp.exp(rho)),  # jnp.clip(, 1e-6, 1e6)
            obs=self.cases[self._nan_idx],
        )

        # with lineage_context:
        npy.sample(
            Sites.LINEAGE,
            MultinomialProbs(
                p[self._nan_idx][:, self.lineage_dates], total_count=self.lineages[self._nan_idx].sum(-1), scale=self.multinomial_scale
            ),
            obs=self.lineages[self._nan_idx],
        )

    def guide(self):
        """Variational guide for :meth:`model`.

        Registers location/scale parameters and samples ``rho`` (only when
        ``fit_rho`` is set), ``beta_0``, ``mu_b`` and ``bc``.  Initial scales
        are the prior scales multiplied by ``init_scale``.
        """
        num_ltla = self.cases.shape[0]
        num_time = self.cases.shape[1]
        num_lin = self.lineages.shape[-1] - 1
        num_basis = self.B.shape[-1]
        num_ltla_lin = self._nan_idx.shape[0]
        num_time_lin = self.lineages.shape[1]

        if self.fit_rho:
            rho_loc = npy.param(
                Sites.RHO + Sites.LOC,
                self.rho_loc * jnp.ones((num_ltla, 1)),
            )
            rho_scale = npy.param(
                Sites.RHO + Sites.SCALE,
                self.init_scale * self.rho_scale * jnp.ones((num_ltla, 1)),
                constraint=dist.constraints.positive,
            )
            rho = npy.sample(Sites.RHO, dist.Normal(rho_loc, rho_scale))

        # mean / sd for parameter s
        beta_0_loc = npy.param(
            Sites.BETA_0 + Sites.LOC, self.beta_loc * jnp.ones((num_ltla, num_basis))
        )
        beta_0_scale = npy.param(
            Sites.BETA_0 + Sites.SCALE,
            self.init_scale
            * self.beta_scale
            * jnp.stack(num_ltla * [jnp.eye(num_basis)]),
            constraint=dist.constraints.lower_cholesky,
        )

        # cov = jnp.matmul(β_σ, jnp.transpose(β_σ, (0, 2, 1)))
        beta_0 = npy.sample(
            Sites.BETA_0, dist.MultivariateNormal(beta_0_loc, scale_tril=beta_0_scale)
        )  # cov

        # NOTE: this guide has no site for the beta_1 coefficients that
        # aggregate() samples.

        mu_b_loc = npy.param(Sites.MU_B + Sites.LOC, jnp.repeat(self.mu_b_loc, num_lin))

        if self.fit_mu_b_mvn:
            mu_b_scale = npy.param(
                Sites.MU_B + Sites.SCALE,
                jnp.diag(self.init_scale * self.mu_b_scale * jnp.ones(num_lin)),
                constraint=dist.constraints.lower_cholesky,
            )
            mu_b = npy.sample(
                Sites.MU_B, dist.MultivariateNormal(mu_b_loc, scale_tril=mu_b_scale)
            )
        else:
            mu_b_scale = npy.param(
                Sites.MU_B + Sites.SCALE,
                self.init_scale * self.mu_b_scale * self.time_scale * jnp.ones(num_lin),
                constraint=dist.constraints.positive,
            )
            mu_b = npy.sample(Sites.MU_B, dist.Normal(mu_b_loc, mu_b_scale))

        # Joint (b, c) block per kept location: slopes first, then offsets.
        bc_loc = npy.param(
            "bc_loc",
            jnp.repeat(
                jnp.concatenate(
                    [self.b_loc * jnp.ones(num_lin), self.c_loc * jnp.ones(num_lin)]
                ).reshape(1, -1),
                num_ltla_lin,
                0,
            ),
        )
        bc_scale = npy.param(
            "bc_scale",
            jnp.repeat(
                jnp.diag(
                    jnp.concatenate(
                        [
                            self.init_scale
                            * self.b_scale
                            * self.time_scale
                            * jnp.ones(num_lin),
                            self.init_scale * self.c_scale * jnp.ones(num_lin),
                        ]
                    )
                ).reshape(1, 2 * num_lin, 2 * num_lin),
                num_ltla_lin,
                0,
            ),
            constraint=dist.constraints.lower_cholesky,
        )

        npy.sample("bc", dist.MultivariateNormal(bc_loc, scale_tril=bc_scale))


In [18]:
# Build the model from the case matrix, the reduced lineage counts, the
# positions of the lineage dates on the case time axis and the pop18 column.
model = IndependentMultiLineage(
    cases=cases,
    lineages=lineages_red,
    lineage_dates=lineage_date_idx,
    population=eng["pop18"].values,
)

In [19]:
# Run the fit with default settings. fit() is not defined in this notebook
# (presumably inherited from SVIModel -- confirm). The "p (...)" lines in the
# saved output come from the print() call inside model().
model.fit()

p (317, 180, 11)
p (317, 180, 11)
p (317, 180, 11)


In [20]:
import matplotlib.pyplot as plt

In [None]:
# Loss trace on a log-scaled y axis, labelled so the figure can be read on its
# own; the trailing semicolon keeps the artist repr out of the cell output.
fig, ax = plt.subplots()
ax.semilogy(model.loss)
ax.set(xlabel="iteration", ylabel="loss", title="Loss trace");

In [44]:
# Summarise the 'beta_0' posterior (the second argument is presumably the axis
# that is averaged over -- confirm).
# NOTE(review): the saved output of this cell is all NaN, so the fit that
# produced it did not yield finite values, and the cell's execution count is
# not sequential with the cells above. Re-check after Restart & Run All.
model.posterior.mean('beta_0', 0)

DeviceArray([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
             nan, nan, nan, nan, nan, nan, nan, nan, nan], dtype=float32)