# VAR GLLVM

## Model Specification

Let $y_{i1t},y_{i2t},\ldots,y_{ipt}$ be the $p$ observed responses at time $t$, $t=1,\ldots,T$, for individual $i$, $i=1,\ldots,n$. Let $\mathbf{x}_{it}$ be the $k$-dimensional vector of observed covariates for individual $i$ at time $t$.

Models for multivariate longitudinal data have to account for three sources of variability
present in the data: (i) cross-sectional associations between the responses at a particular time point, (ii) cross-lagged
associations between different responses at different occasions, and (iii) the association between repeated measures of the same response
over time. The first source of variability is accounted for by
time-dependent latent variables $z_{i1}, z_{i2},\ldots,z_{iT}$, shared by all responses at a given time point. Modeling the temporal evolution of the latent variables accounts for the cross-lagged associations between different responses over time.
The third source of variability is accounted for by a set of item-specific random effects $\mathbf{u}_{i}=(u_{i1}, \ldots, u_{ip})'$, which are constant over time.

According to the GLLVM framework, we have

\begin{align*}
   y_{ijt}\mid\mu_{ijt} &\sim \mathcal{F}_j(y_{ijt}\mid \mu_{ijt}, \tau_j),\\
   g_j(\mu_{ijt}) &= \eta_{ijt} = \beta_{0jt} + \mathbf{x}_{it}^{\top}\boldsymbol \beta_{jt} + z_{it}^{\top}\lambda_{jt}+u_{ij}\sigma_{u_j},
\end{align*}
where $g_j(\cdot)$, $j=1,\ldots,p$, is a known *link function*, $\eta_{ijt}$, $i=1,\ldots,n$, $j=1,\ldots,p$, $t=1,\ldots,T$, is the *linear predictor*, and $\mathcal{F}_j(y_{ijt}\mid \mu_{ijt}, \tau_j)$ denotes a distribution from the exponential family with mean $\mu_{ijt}=g_j^{-1}(\eta_{ijt})$ and response-specific dispersion parameter $\tau_j$. The scale $\sigma_{u_j}>0$ is the standard deviation of the random effect of response $j$.
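As an illustration of the measurement equation, the following NumPy sketch computes $\eta_{ijt}$ and $\mu_{ijt}$ for a mix of Bernoulli (logit link) and Poisson (log link) responses. It assumes time-invariant coefficients, as in the `VARGLLVM` class below; the helper name `conditional_mean` and all array sizes are hypothetical.

```python
import numpy as np

def conditional_mean(x, z, u, beta0, beta_x, lam, sigma_u, response_types):
    """Hypothetical helper: compute eta_{ijt} and mu_{ijt}.

    Shapes (time-invariant coefficients assumed):
      x: (n, T, k), z: (n, T, q), u: (n, p),
      beta0: (p,), beta_x: (k, p), lam: (q, p), sigma_u: (p,)
    """
    # Linear predictor: intercept + covariate effect + latent effect + random effect.
    # u[:, None, :] adds a time axis so each u_ij is shared by all time points.
    eta = beta0 + x @ beta_x + z @ lam + u[:, None, :] * sigma_u
    mu = np.empty_like(eta)
    inverse_links = {
        "bernoulli": lambda e: 1.0 / (1.0 + np.exp(-e)),  # inverse logit
        "poisson": np.exp,                                 # inverse log
    }
    # Apply each family's inverse link to its own columns
    for response_type, idx in response_types.items():
        mu[:, :, idx] = inverse_links[response_type](eta[:, :, idx])
    return eta, mu

rng = np.random.default_rng(0)
n, T, k, q, p = 4, 5, 2, 2, 3
eta, mu = conditional_mean(
    x=rng.normal(size=(n, T, k)),
    z=rng.normal(size=(n, T, q)),
    u=rng.normal(size=(n, p)),
    beta0=np.zeros(p),
    beta_x=rng.normal(size=(k, p)),
    lam=rng.normal(size=(q, p)),
    sigma_u=np.ones(p),
    response_types={"bernoulli": [0, 1], "poisson": [2]},
)
print(eta.shape, mu.shape)  # (4, 5, 3) (4, 5, 3)
```

The link is applied per response family, which is why the mapping from family to column indices is needed at all.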
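The dynamics of the latent variables over time are modelled through a stationary first-order vector autoregressive model, VAR(1):


$$
\begin{aligned}
z_{i,t} &= A z_{i,t-1} + \epsilon_{i,t}, \qquad t=2,\ldots,T,\\
\epsilon_{i,t} &\sim N(0, I),\\
\lVert A\rVert_2 &< 1.
\end{aligned}
$$

Since the spectral radius of $A$ is bounded by its spectral norm, the constraint $\lVert A\rVert_2<1$ keeps all eigenvalues of $A$ inside the unit circle, so the VAR(1) is stable.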
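A minimal NumPy sketch of the latent dynamics: it simulates $z_{i,t}$ for a hypothetical transition matrix `A` and checks the spectral-norm constraint with `np.linalg.norm(A, 2)` (the largest singular value). Drawing $z_{i,1}$ from $N(0,\Sigma_{z1})$ by scaling the first shock follows the initialization described next; whether the `VAR1` module does exactly this is an assumption, and `simulate_var1` is a hypothetical name.

```python
import numpy as np

def simulate_var1(A, sd_z1, epsilon):
    """Hypothetical helper: z_{i,1} = sd_z1 * eps_{i,1}, z_{i,t} = A z_{i,t-1} + eps_{i,t}.

    A: (q, q) transition matrix, sd_z1: (q,) initial standard deviations,
    epsilon: (n, T, q) standard normal shocks.
    """
    # The spectral norm (largest singular value) must be below 1
    assert np.linalg.norm(A, 2) < 1, "||A||_2 must be < 1"
    n, T, q = epsilon.shape
    z = np.empty_like(epsilon)
    z[:, 0] = sd_z1 * epsilon[:, 0]
    for t in range(1, T):
        # Row-vector convention: z_t^T = z_{t-1}^T A^T + eps_t^T
        z[:, t] = z[:, t - 1] @ A.T + epsilon[:, t]
    return z

rng = np.random.default_rng(1)
A = np.array([[0.5, 0.2], [0.0, 0.3]])
eps = rng.normal(size=(100, 10, 2))
z = simulate_var1(A, sd_z1=np.array([1.0, 1.0]), epsilon=eps)
print(z.shape)  # (100, 10, 2)
```

Because the recursion is written for a batch of row vectors, `A` is applied through `A.T` on the right.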

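The process is initialized at $t=1$:

$$
\begin{aligned}
z_{i,1} &\sim N(0, \Sigma_{z1}),\\
\Sigma_{z1} &= \operatorname{diag}\left(\sigma_{z1,1}^2, \ldots, \sigma_{z1,q}^2\right), \qquad \sigma_{z1,l}>0 \ \ \forall l,
\end{aligned}
$$

where $q$ is the number of latent variables.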
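A related computation, not part of the specification above: when $\lVert A\rVert_2<1$, the VAR(1) has a stationary covariance $\Sigma_\infty$ solving $\Sigma_\infty = A\Sigma_\infty A^\top + I$, and $z_{i,t}$ is stationary from $t=1$ only if $\Sigma_{z1}=\Sigma_\infty$. With a diagonal $\Sigma_{z1}$ this generally does not hold, and the marginal covariance of $z_{i,t}$ instead converges to $\Sigma_\infty$ as $t$ grows. The sketch solves the equation through $\operatorname{vec}(\Sigma_\infty) = (I - A\otimes A)^{-1}\operatorname{vec}(I)$; the helper name `stationary_covariance` and the matrix `A` are hypothetical.

```python
import numpy as np

def stationary_covariance(A):
    """Hypothetical helper: solve Sigma = A Sigma A^T + I for a stable VAR(1)."""
    q = A.shape[0]
    # vec(A Sigma A^T) = (A kron A) vec(Sigma), so (I - A kron A) vec(Sigma) = vec(I)
    vec_sigma = np.linalg.solve(np.eye(q * q) - np.kron(A, A), np.eye(q).ravel())
    return vec_sigma.reshape(q, q)

A = np.array([[0.5, 0.2], [0.0, 0.3]])
Sigma_inf = stationary_covariance(A)

# Propagate Cov(z_t) = A Cov(z_{t-1}) A^T + I starting from a diagonal Sigma_z1
cov = np.diag([1.0, 1.0])
for _ in range(50):
    cov = A @ cov @ A.T + np.eye(2)
print(np.allclose(cov, Sigma_inf))  # True
```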

Moreover, we assume that the random effects are independent of the latent variables and follow the common distribution $\mathbf{u}_{i}\sim N_p(\mathbf{0}, \boldsymbol I)$.
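Because $u_{ij}$ carries no time index, the same draw enters $\eta_{ijt}$ at every $t$, which is what induces the association between repeated measures of response $j$: the random effect alone contributes $\sigma_{u_j}^2$ to the covariance between $\eta_{ijt}$ and $\eta_{ijs}$. A NumPy sketch of this, mirroring the `(num_batch, 1, num_var)` shape used in `VARGLLVM.sample`, with hypothetical values for `sigma_u`:

```python
import numpy as np

rng = np.random.default_rng(2)
n, T, p = 20000, 4, 3
sigma_u = np.array([0.5, 1.0, 2.0])        # hypothetical random-effect standard deviations

# One draw per individual and response, with a time axis of length 1
u = rng.normal(size=(n, 1, p))
# Broadcasting repeats the same value at every time point
contribution = np.broadcast_to(u * sigma_u, (n, T, p))

# The contribution is identical over time ...
print(np.ptp(contribution, axis=1).max())  # 0.0
# ... so its covariance between two time points is sigma_u**2 (E[u] = 0, so Cov = E[product])
emp_cov = np.mean(contribution[:, 0] * contribution[:, 3], axis=0)
print(np.round(emp_cov, 2))
```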



# TODO:

* Check whether `seq_length` needs to be defined as a model attribute. Probably not, since no parameter depends on `seq_length`; for now we keep it, in case the encoder needs it.
* Test that the operations in `forward` are correct.
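One way to approach the second item, sketched under assumptions: compare the vectorized linear predictor used in `forward` (intercepts broadcast over batch and time, `x @ wx`, `z @ wz`, and a random effect of shape `(n, 1, p)`) against an explicit loop over $i$, $t$, $j$. The sketch uses NumPy arrays with random values in place of the module's parameters and a random `z` in place of the `VAR1` output, so it checks the algebra, not the module itself.

```python
import numpy as np

rng = np.random.default_rng(3)
n, T, k, q, p = 3, 4, 2, 2, 5
x = rng.normal(size=(n, T, k))
z = rng.normal(size=(n, T, q))           # stands in for the VAR1 output
u = rng.normal(size=(n, 1, p))           # random-effect shocks, constant over time
intercepts = rng.normal(size=p)
wx = rng.normal(size=(k, p))
wz = rng.normal(size=(q, p))
logvar_u = rng.normal(size=p)

# Vectorized version, mirroring the steps in VARGLLVM.forward
linpar = intercepts[None, None, :] + x @ wx + z @ wz + u * np.sqrt(np.exp(logvar_u[None, None, :]))

# Reference version with explicit loops over individuals, time points and responses
reference = np.empty((n, T, p))
for i in range(n):
    for t in range(T):
        for j in range(p):
            reference[i, t, j] = (
                intercepts[j]
                + x[i, t] @ wx[:, j]
                + z[i, t] @ wz[:, j]
                + u[i, 0, j] * np.exp(0.5 * logvar_u[j])
            )

print(np.allclose(linpar, reference))  # True
```

The same comparison could be run against the module's own parameters once a `VAR1` output is available.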

In [66]:
import torch
import torch.nn as nn
from VAR1 import VAR1

class VARGLLVM(nn.Module):
    """
    Vector Autoregressive Generalized Linear Latent Variable Model (VAR GLLVM) for multivariate longitudinal data.
    
    This class extends the GLLVM to longitudinal data, modeling multivariate responses over time.
    The variability in the data is captured through three sources: cross-sectional associations,
    cross-lagged associations, and associations between repeated measures of the same response over time.
    The model specification integrates a time-dependent latent variable and an item-specific random effect.
    
    Parameters:
    - num_var: int, Number of response variables.
    - num_latent: int, Number of latent variables.
    - num_covar: int, Number of observed covariates.
    - response_types: dict, Mapping of response type to its indices.
    - add_intercepts: bool, Whether to include intercepts in the model.
    
    Key Methods:
    - forward: Computes the conditional mean of the model.
    - sample: Draws samples from the VARGLLVM model.
    - sample_response: Samples from the response distribution based on conditional mean.
    - linpar2condmean: Converts linear predictors to conditional means.

    Model Specification:
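    y_{ijt}|μ_{ijt} ~ F_j(y_{ijt}|μ_{ijt}, τ_j)
    g_j(μ_{ijt}) = η_{ijt} = β_{0jt} + x_{it}^T*β_{jt} + z_{it}^T*λ_{jt} + u_{ij}*σ_{u_j}
    (In this implementation the coefficients are time-invariant: intercepts, wx and wz do not depend on t.)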
    
    Where:
    - g_j: link function
    - μ_{ijt}: mean of the distribution
    - F_j: a distribution from the exponential family
    - η_{ijt}: linear predictor

    Temporal evolution of the latent variable:
    z_{it} = A*z_{i,t-1} + ε_{it}
    with ε_{it} ~ N(0, I) and initialization z_{i, 1} ~ N(0, Σ_{z1}).
    
    The random effects are assumed independent of the latent variable and distributed as:
    u_{i} ~ N_p(0, I).
    """
    
    def __init__(self, num_var, num_latent, num_covar, response_types, add_intercepts = True):
        super().__init__()
        self.response_types = response_types
        self.response_link = {
            'bernoulli' : lambda x: torch.logit(x),
            'ordinal': lambda x: torch.logit(x),
            'poisson': lambda x: torch.log(x)
        }
        self.response_linkinv = {
            'bernoulli': lambda x: 1/(1+torch.exp(-x)),
            'ordinal': lambda x: 1/(1+torch.exp(-x)),
            'poisson': lambda x: torch.exp(x)
        }
        self.response_transform = {
            'bernoulli' : lambda x: 2*x - 1,
            'ordinal': lambda x: 2*x - 1,
            'poisson': lambda x: torch.log(x+1)
        }

        self.num_var = num_var
        self.num_latent = num_latent
        self.num_covar = num_covar

        # Define Parameters
        # -----------------
        # Parameters for the VAR
        self.A = nn.Parameter(torch.zeros((num_latent, num_latent)))
        self.logvar_z1 = nn.Parameter(torch.zeros((num_latent,)))

        # Parameters for the outcome model
        if add_intercepts:
            self.intercepts = nn.Parameter(torch.zeros((num_var,)))
        else:
            self.intercepts = None

        self.wz = nn.Parameter(torch.randn((num_latent, num_var)))
        self.wx = nn.Parameter(torch.randn((num_covar, num_var)))

        # Parameters for the random effects
        self.logvar_u = nn.Parameter(torch.zeros((num_var,)))
        
        # Define Modules
        # --------------
        self.VAR1 = VAR1(A=self.A, logvar_z1 = self.logvar_z1)

    def forward(self, epsilon, u, x = None):
        """
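        Compute the linear predictor and the conditional mean of a VARGLLVM.

        Parameters:
            - epsilon: shocks for the VAR, of shape (num_batch, seq_length, num_latent)
            - u: shocks for the random effects, of shape (num_batch, 1, num_var)
            - x: covariates of shape (num_batch, seq_length, num_covar), or None if num_covar == 0

        Returns:
            - (linpar, condmean), both of shape (num_batch, seq_length, num_var)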
        """
        assert epsilon.shape[2] == self.num_latent, "bad shape for epsilon"
        assert u.shape[1:] == (1, self.num_var), "bad shape for u"

        # Computing linpar
        # ----------------
        linpar = torch.zeros((epsilon.shape[0], epsilon.shape[1], self.num_var), device=epsilon.device)

        # add intercepts, one per variable
        if self.intercepts is not None:
            linpar += self.intercepts[None, None, :]

        # add covariates' effects
        if x is None:
            assert self.num_covar == 0, f'VARGLLVM module expected {self.num_covar} covariates, received none.'
        else:
            assert(x.shape[1:] == (epsilon.shape[1], self.num_covar))
            linpar += x @ self.wx

        # add latent variables' effects
        z = self.VAR1(epsilon)
        linpar += z @ self.wz

        # finally, add random effects
        # u has shape (num_batch, 1, num_var): the same random effect is broadcast over all time points
        linpar += u * torch.sqrt(torch.exp(self.logvar_u[None, None, :]))
        
        # compute the conditional mean
        condmean = self.linpar2condmean(linpar)
        return (linpar, condmean)
    
    def sample(self, num_batch, seq_length,  x = None, epsilon = None, u = None):
        """
        Sample from the VARGLLVM. 

        Parameters:
            - num_batch: number of observational units
            - seq_length: length of the sequence. If x or epsilon is provided, its second dimension must equal seq_length (u has no time dimension).
            - x: tensor of shape (num_batch, seq_length, num_covar). If the VARGLLVM model was initialized with num_covar >= 1, cannot be None.
            - epsilon: the shocks for the latent variables of shape (num_batch, seq_length, num_latent). If None, those are drawn iid from N(0, 1).
            - u: the shocks for the random effects of shape (num_batch, 1, num_var). If None, those are drawn iid from N(0, 1).

        """
        device = self._get_device()
        if epsilon is None:
            epsilon = torch.randn((num_batch, seq_length, self.num_latent), device=device)
        if u is None:
            u = torch.randn((num_batch, 1, self.num_var), device=device) # one per var, but constant across time
        
        linpar, condmean = self(epsilon, u, x)
        y = self.sample_response(condmean)

        return (linpar, condmean, y)

    def sample_response(self, mean):
        """Draw responses given the conditional mean, one response family at a time."""
        device = self._get_device()
        y = torch.zeros_like(mean, device=device)
        for response_type, response_id in self.response_types.items():
            if response_type == "bernoulli":
                y[:,:,response_id] = torch.bernoulli(mean[:,:,response_id]).to(device)
            elif response_type == "ordinal":
                # cumulative probabilities, one column per threshold
                cum_probs = mean[:,:,response_id]
                # draw one uniform per unit and time point, shared across thresholds
                random = torch.rand((*cum_probs.shape[0:2], 1), device=device)
                # the category is the number of cumulative probabilities exceeded
                ordinal = torch.sum(random > cum_probs, dim=2)
                # fix num_classes so the encoding width does not depend on which categories were drawn
                ordinal = torch.nn.functional.one_hot(ordinal, num_classes=cum_probs.shape[2] + 1).float()
                ordinal = ordinal[:,:,1:] # drop the first column: the reference category is encoded as all zeros
                y[:,:,response_id] = ordinal
            elif response_type == "poisson":
                y[:,:,response_id] = torch.poisson(mean[:,:,response_id])
        return y
    
    def linpar2condmean(self, linpar):
        """Apply each response family's inverse link to its columns of the linear predictor."""
        mean = torch.zeros_like(linpar)
        for response_type, response_id in self.response_types.items():
            mean[:,:,response_id] = self.response_linkinv[response_type](linpar[:,:,response_id])
        return mean

    def _get_device(self):
        return next(self.parameters()).device
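The ordinal branch of `sample_response` draws one uniform per unit and time point and counts how many cumulative probabilities it exceeds. A NumPy sketch of that idea, assuming the columns of `cum_probs` are increasing cumulative probabilities $P(Y\le c)$ for the first $K-1$ categories (the class itself does not enforce this); `sample_ordinal` is a hypothetical name:

```python
import numpy as np

def sample_ordinal(cum_probs, rng):
    """Hypothetical helper: cum_probs has shape (n, T, K - 1), increasing along the last axis."""
    n, T, K_minus_1 = cum_probs.shape
    # One uniform per (unit, time), shared across the K - 1 thresholds
    uniform = rng.random((n, T, 1))
    # Category index in {0, ..., K - 1}: number of cumulative probabilities exceeded
    category = np.sum(uniform > cum_probs, axis=2)
    # One-hot with a fixed number of classes, then drop the reference category 0
    one_hot = np.eye(K_minus_1 + 1)[category]
    return category, one_hot[:, :, 1:]

rng = np.random.default_rng(4)
cum_probs = np.broadcast_to(np.array([0.2, 0.5, 0.9]), (50000, 1, 3))
category, encoded = sample_ordinal(cum_probs, rng)
print(encoded.shape)  # (50000, 1, 3)
print(np.bincount(category.ravel(), minlength=4) / category.size)
```

With these cumulative probabilities the category frequencies should be close to 0.2, 0.3, 0.4 and 0.1, since $P(Y=c) = P(Y\le c) - P(Y\le c-1)$.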


In [64]:
n = 1000
num_var = 13
num_latent = 2
num_covar = 0
seq_length = 10
x = torch.randn((n, seq_length, num_covar))

response_types = {
    'bernoulli': [0,1,2,3,4,5,6,7, 9, 10],
    'poisson': [8, 11, 12]
}

model = VARGLLVM(num_var = num_var,
                 num_latent = num_latent,
                 num_covar = num_covar,
                 response_types = response_types,
                 add_intercepts=False)
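`response_types` maps each response family to the column indices it governs, and `linpar2condmean` fills only those columns, so an index left out would silently keep a conditional mean of 0. A small check that every index from 0 to `num_var - 1` appears exactly once; `check_response_types` is a hypothetical helper, not part of the class:

```python
def check_response_types(response_types, num_var):
    """Hypothetical helper: every response index must belong to exactly one family."""
    indices = [j for ids in response_types.values() for j in ids]
    duplicates = sorted({j for j in indices if indices.count(j) > 1})
    missing = sorted(set(range(num_var)) - set(indices))
    out_of_range = sorted(set(indices) - set(range(num_var)))
    if duplicates or missing or out_of_range:
        raise ValueError(
            f"duplicated: {duplicates}, missing: {missing}, out of range: {out_of_range}"
        )
    return True

response_types = {
    'bernoulli': [0, 1, 2, 3, 4, 5, 6, 7, 9, 10],
    'poisson': [8, 11, 12]
}
print(check_response_types(response_types, num_var=13))  # True
```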

In [None]:
linpar, condmean, y = model.sample(num_batch=n, seq_length=seq_length, x=x)