In [4]:
import jax
import jax.numpy as jnp
from ml_collections import config_flags
import numpy as np
from absl import app, flags

from tensorflow.keras.datasets import mnist

import math, os
from typing import Any, Callable, Mapping, Sequence, Tuple, Optional
from jax import Array
from flax import linen as nn
from flax.training import train_state, orbax_utils
import optax
import wandb
from orbax.checkpoint import CheckpointManager, PyTreeCheckpointer
from pathlib import Path
import tqdm
from einops import rearrange

from ddprism import training_utils
from ddprism.corrupted_mnist import datasets
from ddprism.metrics import metrics, image_metrics

# NOTE: The block below is a bare string literal, not a comment. It holds
# disabled imports and a dataset path that is absolute and user-specific.
# Because it is the last expression of the cell, the notebook echoes its repr
# as the cell output.
'''
import config_base_grass
import config_base_mnist
imagenet_path = '/mnt/home/aakhmetzhanova/ceph/galaxy-diffusion/corrupted-mnist/dataset/grass_jpeg/'
'''

"\nimport config_base_grass\nimport config_base_mnist\nimagenet_path = '/mnt/home/aakhmetzhanova/ceph/galaxy-diffusion/corrupted-mnist/dataset/grass_jpeg/'\n"

In [5]:
class cLVM(nn.Module):
    r"""Creates an instance of cLVM model.

    Latent variables are drawn from the encoders of a supplied model and mapped
    back to data space with its decoders.

    Arguments:
        model: Model used for cLVM method.
    """
    # NOTE(review): none of the methods below read this field; they all use the
    # `model` argument passed by the caller instead.
    model: nn.Module

    def sample_latents(
        self, rng: Array, model: nn.Module, x: Array, y: Array, a_mat: Optional[Array] = None
    ) -> Tuple:
        r"""Samples latent variables for the target and background datasets.

        Arguments:
            rng: PRNG key, split into one independent key per latent draw.
            model: Object providing ``encode_target``, ``encode_bkg``,
                ``target_latent_dim`` and ``bkg_latent_dim``.
            x: Target batch. Its leading axis is used as the batch size.
            y: Background batch. Its leading axis is used as the batch size.
            a_mat: Optional matrix forwarded unchanged to the encoders.

        Returns:
            Tuple ``(tx, zx, zy)``: target latents of ``x``, background latents
            of ``x`` and background latents of ``y``.
        """
        rng_zx, rng_zy, rng_tx = jax.random.split(rng, 3)

        # Compute mean and log-std for the target and background datasets.
        mu_zx, log_sigma_zx, mu_tx, log_sigma_tx = model.encode_target(x, a_mat)
        # The background encoder must see the background batch `y`, not `x`.
        mu_zy, log_sigma_zy = model.encode_bkg(y, a_mat)

        # Sample latent variables corresponding to the enriched signal in the target dataset.
        # The shape is built as a tuple: `x.shape[0]` is an int and cannot be
        # added to a tuple directly.
        eps_tx = jax.random.normal(rng_tx, shape=(x.shape[0], model.target_latent_dim))
        tx = mu_tx + jnp.exp(log_sigma_tx) * eps_tx

        # Sample latent variables corresponding to the background in the target dataset.
        eps_zx = jax.random.normal(rng_zx, shape=(x.shape[0], model.bkg_latent_dim))
        zx = mu_zx + jnp.exp(log_sigma_zx) * eps_zx

        # Sample latent variables corresponding to the background in the background dataset.
        eps_zy = jax.random.normal(rng_zy, shape=(y.shape[0], model.bkg_latent_dim))
        zy = mu_zy + jnp.exp(log_sigma_zy) * eps_zy

        return tx, zx, zy

    def expect_data(
        self, model: nn.Module, tx: Array, zx: Array, zy: Array, a_mat: Optional[Array] = None
    ) -> Tuple:
        r"""Computes expected values of the target and background datasets.

        Arguments:
            model: Object providing ``decode_bkg`` and ``decode_target``.
            tx: Target latents of the target batch.
            zx: Background latents of the target batch.
            zy: Background latents of the background batch.
            a_mat: Optional matrix forwarded unchanged to the decoders.

        Returns:
            Tuple ``(x, y)`` where ``x`` is the sum of the decoded background
            and decoded target components and ``y`` is the decoded background.
        """
        x = model.decode_bkg(zx, a_mat) + model.decode_target(tx, a_mat)
        y = model.decode_bkg(zy, a_mat)
        return x, y

    @nn.compact
    def __call__(
        self, rng: Array, model: nn.Module, x: Array, y: Array, a_mat: Optional[Array] = None
    ) -> Tuple:
        r"""Generates new samples of the data.

        Arguments:
            rng: PRNG key used to sample the latent variables.
            model: Object providing the encoders and decoders (see
                ``sample_latents`` and ``expect_data``).
            x: Target batch.
            y: Background batch.
            a_mat: Optional matrix forwarded to the encoders and decoders.

        Returns:
            Tuple ``(x, y)`` of expected data given the sampled latents.
        """
        # Sample latent variables.
        tx, zx, zy = self.sample_latents(rng, model, x, y, a_mat)
        # Compute expectation for the data, given the latents.
        x, y = self.expect_data(model, tx, zx, zy, a_mat)
        return x, y
    

In [9]:
class clvmLinear(nn.Module):
    r"""Creates an instance of cLVM model that linearly maps latent variables to the observed space.

    NOTE(review): this class is unfinished. ``sample_latents`` has no
    implementation and the initializers and shapes of ``s_mat`` and ``w_mat``
    are still ``...`` placeholders, so ``__call__`` cannot run end-to-end yet.
    """

    def sample_latents(
        self, rng: Array, model_params: Tuple, x: Array, y: Array, a_mat: Optional[Array] = None
    ) -> Tuple:
        r"""Samples latent variables for the target and background datasets.

        Arguments:
            rng: PRNG key.
            model_params: Tuple ``(w_mat, s_mat, mu_x, mu_y)``.
            x: Target batch.
            y: Background batch.
            a_mat: Optional matrix.

        Raises:
            NotImplementedError: Always. Sampling for the linear model has not
                been written yet; raising is clearer than returning ``None``
                and failing later when the result is unpacked.
        """
        raise NotImplementedError(
            "clvmLinear.sample_latents is not implemented yet."
        )

    def expect_data(
        self, model_params: Tuple, tx: Array, zx: Array, zy: Array, a_mat: Optional[Array] = None
    ) -> Tuple:
        r"""Computes expected values of the target and background datasets.

        Arguments:
            model_params: Tuple ``(w_mat, s_mat, mu_x, mu_y)``.
            tx: Target latents of the target batch.
            zx: Background latents of the target batch.
            zy: Background latents of the background batch.
            a_mat: Optional matrix applied on the left of both outputs.

        Returns:
            Tuple ``(x, y)`` of expected target and background data.
        """
        w_mat, s_mat, mu_x, mu_y = model_params
        # NOTE(review): `s_mat @ zx` contracts the last axis of `s_mat` with the
        # leading (or only) axis of `zx`. Confirm this matches the layout of
        # batched latents.
        x = s_mat @ zx + w_mat @ tx + mu_x
        y = s_mat @ zy + mu_y

        if a_mat is not None:
            x = a_mat @ x
            y = a_mat @ y
        return x, y

    @nn.compact
    def __call__(
        self, rng: Array, x: Array, y: Array, a_mat: Optional[Array] = None
    ) -> Tuple:
        r"""Generates new samples of the data.

        Arguments:
            rng: PRNG key used to sample the latent variables.
            x: Target batch. Its last axis sets the size of ``mu_x``.
            y: Background batch. Its last axis sets the size of ``mu_y``.
            a_mat: Optional matrix forwarded to sampling and decoding.

        Returns:
            Tuple ``(x, y)`` of expected data given the sampled latents.
        """
        # Following GaussianDenoiserDPLR model?
        # Initialize model parameters.
        mu_x = self.param(
            "mu_x", lambda rng, shape: jnp.ones(shape), x.shape[-1:]
        )

        # Registered under its own name; reusing "mu_x" would collide with the
        # parameter declared above.
        mu_y = self.param(
            "mu_y", lambda rng, shape: jnp.ones(shape), y.shape[-1:]
        )

        # TODO: replace the `...` placeholders with a real initializer and shape.
        s_mat = self.param(
            "s_mat",
            lambda rng, shape: ...,
            (..., )
        )

        # TODO: replace the `...` placeholders with a real initializer and shape.
        w_mat = self.param(
            "w_mat",
            lambda rng, shape: ...,
            (..., )
        )

        model_params = (w_mat, s_mat, mu_x, mu_y)

        # Generate new samples of the data.
        # Sample latent variables.
        tx, zx, zy = self.sample_latents(rng, model_params, x, y, a_mat)

        # Compute expectation for the data, given the latents.
        x, y = self.expect_data(model_params, tx, zx, zy, a_mat)
        return x, y
    

In [10]:
class clvmVAE(clvmLinear):
    r"""Creates an instance of cLVM model that non-linearly maps latent variables to the observed space.

    Arguments:
        vae_z: Module whose ``encoder`` / ``decoder`` handle the background latents.
        vae_t: Module whose ``encoder`` / ``decoder`` handle the target latents.
    """
    vae_z: nn.Module
    vae_t: nn.Module

    def sample_latents(
        self, rng: Array, x: Array, y: Array, a_mat: Optional[Array] = None
    ) -> Tuple:
        r"""Samples latent variables for the target and background datasets.

        Arguments:
            rng: PRNG key, split into one independent key per latent draw.
            x: Target batch.
            y: Background batch.
            a_mat: Optional matrix forwarded unchanged to the encoders.

        Returns:
            Tuple ``(tx, zx, zy)``: target latents of ``x``, background latents
            of ``x`` and background latents of ``y``.
        """
        rng_zx, rng_zy, rng_tx = jax.random.split(rng, 3)

        # Compute mean and log-std for the target and background datasets.
        mu_tx, log_sigma_tx = self.vae_t.encoder(x, a_mat)

        mu_zx, log_sigma_zx = self.vae_z.encoder(x, a_mat)
        mu_zy, log_sigma_zy = self.vae_z.encoder(y, a_mat)

        # The noise takes the shape of the matching encoder mean, so no latent
        # size has to be looked up on another object.

        # Sample latent variables corresponding to the enriched signal in the target dataset.
        eps_tx = jax.random.normal(rng_tx, shape=mu_tx.shape)
        tx = mu_tx + jnp.exp(log_sigma_tx) * eps_tx

        # Sample latent variables corresponding to the background in the target dataset.
        eps_zx = jax.random.normal(rng_zx, shape=mu_zx.shape)
        zx = mu_zx + jnp.exp(log_sigma_zx) * eps_zx

        # Sample latent variables corresponding to the background in the background dataset.
        eps_zy = jax.random.normal(rng_zy, shape=mu_zy.shape)
        zy = mu_zy + jnp.exp(log_sigma_zy) * eps_zy

        return tx, zx, zy

    def expect_data(
        self, tx: Array, zx: Array, zy: Array, a_mat: Optional[Array] = None
    ) -> Tuple:
        r"""Computes expected values of the target and background datasets.

        Arguments:
            tx: Target latents of the target batch.
            zx: Background latents of the target batch.
            zy: Background latents of the background batch.
            a_mat: Optional matrix forwarded to the decoders and then applied
                on the left of both outputs.

        Returns:
            Tuple ``(x, y)`` of expected target and background data.
        """
        x = self.vae_z.decoder(zx, a_mat) + self.vae_t.decoder(tx, a_mat)
        y = self.vae_z.decoder(zy, a_mat)

        # NOTE(review): `a_mat` is already passed to the decoders above. If they
        # apply it themselves, this applies it a second time - confirm.
        if a_mat is not None:
            x = a_mat @ x
            y = a_mat @ y
        return x, y

    @nn.compact
    def __call__(
        self, rng: Array, x: Array, y: Array, a_mat: Optional[Array] = None
    ) -> Tuple:
        r"""Generates new samples of the data.

        Arguments:
            rng: PRNG key used to sample the latent variables.
            x: Target batch.
            y: Background batch.
            a_mat: Optional matrix forwarded to sampling and decoding.

        Returns:
            Tuple ``(x, y)`` of expected data given the sampled latents.
        """
        # Sample latent variables.
        tx, zx, zy = self.sample_latents(rng, x, y, a_mat)

        # Compute expectation for the data, given the latents.
        x, y = self.expect_data(tx, zx, zy, a_mat)
        return x, y
    

In [11]:
@jax.jit
def update_model(state, grads):
    """Apply gradients to the training state.

    Arguments:
        state: Training state exposing ``apply_gradients``.
        grads: Gradients handed to ``state.apply_gradients``.

    Returns:
        Whatever ``state.apply_gradients(grads=grads)`` returns.
    """
    new_state = state.apply_gradients(grads=grads)
    return new_state

In [12]:
@jax.jit
def apply_model(rng, state, x, y, sigma_noise):
    """Computes gradients and loss for a single batch.

    The loss is the squared reconstruction error of both batches, scaled by
    ``1 / (2 * sigma_noise**2)`` and averaged over all elements. It is positive,
    so descending its gradient reduces the reconstruction error.

    Arguments:
        rng: PRNG key, split into one key for the model call and one for the
            ``'dropout'`` stream.
        state: Training state exposing ``apply_fn`` and ``params``.
        x: Target batch.
        y: Background batch.
        sigma_noise: Noise standard deviation used to scale the squared errors.

    Returns:
        Tuple ``(grads, loss, loss_dict)``: gradients with respect to
        ``state.params``, the scalar loss, and a dict with keys ``'loss'``,
        ``'loss_x'`` and ``'loss_y'``.
    """
    # Separate keys so the latent draws and the dropout masks are independent.
    rng_sample, rng_drop = jax.random.split(rng)

    def loss_fn(params):
        # Draw samples in data space.
        x_draws, y_draws = state.apply_fn(
            {'params': params}, rng_sample, x, y,
            rngs={'dropout': rng_drop},
        )

        # Reduce to scalars: `jax.value_and_grad` requires a scalar output.
        scale = 2.0 * sigma_noise**2
        loss_x = jnp.mean(jnp.square(x_draws - x) / scale)
        loss_y = jnp.mean(jnp.square(y_draws - y) / scale)
        loss = loss_x + loss_y
        return loss, {'loss': loss, 'loss_x': loss_x, 'loss_y': loss_y}

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, loss_dict), grads = grad_fn(state.params)

    return grads, loss, loss_dict
