In [8]:
import os

from absl import app, flags
import jax
import jax.numpy as jnp
from ml_collections import config_flags
from orbax.checkpoint import CheckpointManager, CheckpointManagerOptions
from orbax.checkpoint import PyTreeCheckpointer
import optax
from tqdm import tqdm
import wandb

from ddprism import utils
from ddprism.pcpca import pcpca_utils
from ddprism import plotting_utils

from ddprism.pcpca import pcpca_utils

from ddprism.corrupted_mnist import datasets
from ddprism.corrupted_mnist import metrics
from ddprism.corrupted_mnist import config_base_grass, config_base_mnist
from ddprism.corrupted_mnist import config_pcpca as config_pcpca_



In [9]:
import importlib

# Re-execute pcpca_utils in place so edits to the module on disk take effect
# without restarting the kernel. The trailing `;` suppresses the module repr
# that would otherwise be printed as this cell's output.
importlib.reload(pcpca_utils);

<module 'ddprism.pcpca.pcpca_utils' from '/mnt/home/aakhmetzhanova/ddprism/ddprism/pcpca/pcpca_utils.py'>

In [10]:
# Directory of grass JPEG images used to corrupt the MNIST digits; it is passed
# to `datasets.get_corrupted_mnist` as `imagenet_path`. Set the
# GRASS_IMAGES_PATH environment variable to run on a machine where the images
# live elsewhere; otherwise the original cluster path is used.
imagenet_path = os.environ.get(
    'GRASS_IMAGES_PATH',
    '/mnt/home/aakhmetzhanova/ceph/galaxy-diffusion/corrupted-mnist/dataset/grass_jpeg/',
)


In [11]:
def run_pcpca():
    config_pcpca = config_pcpca_.get_config() 
    
    # Generate training datasets.
    # Target dataset with corrupted mnist digits.
    config_mnist = config_base_mnist.get_config()
    rng = jax.random.key(config_mnist.rng_key)
    rng_dataset, rng_comp, rng = jax.random.split(rng, 3)
    
    f_train = datasets.get_corrupted_mnist(
        rng_dataset, grass_amp=1., mnist_amp=config_mnist.mnist_amp,
        imagenet_path=imagenet_path,
        dataset_size=config_mnist.dataset_size,
        zeros_and_ones=True
    )
    # Target dataset with uncorrupted mnist digits for computing metrics later on.
    f_train_uncorrupted = datasets.get_corrupted_mnist(
        rng_dataset, grass_amp=0., mnist_amp=1.,
        imagenet_path=imagenet_path,
        dataset_size=config_mnist.dataset_size,
        zeros_and_ones=True
    )
    # Background dataset with grass only.
    config = config_base_grass.get_config()
    rng = jax.random.key(config.rng_key)
    rng_dataset, rng_comp, rng = jax.random.split(rng, 3)
    b_train = datasets.get_corrupted_mnist(
        rng_dataset, grass_amp=1., mnist_amp=0.,
        imagenet_path=imagenet_path, 
        dataset_size=config.dataset_size,
        zeros_and_ones=True)

    # Generate validation datasets with corrupted mnist digits.
    rng = jax.random.key(config_mnist.rng_key_val)
    rng_dataset, rng_comp, rng = jax.random.split(rng, 3)
    
    # Validation target dataset with corrupted mnist digits.
    f_val = datasets.get_corrupted_mnist(
        rng_dataset, grass_amp=1., mnist_amp=config_mnist.mnist_amp,
        imagenet_path=imagenet_path,
        dataset_size=14720, 
        zeros_and_ones=True
    )
    # Take previously unseen digits for the validation set.
    f_val_images = f_val[0][config_mnist.dataset_size:]
    f_val_labels = f_val[1][config_mnist.dataset_size:]
    f_val = (f_val_images, f_val_labels)
    
    # Validation dataset with uncorrupted mnist digits for computing metrics later on.
    f_val_uncorrupted = datasets.get_corrupted_mnist(
        rng_dataset, grass_amp=0., mnist_amp=1.,
        imagenet_path=imagenet_path,
        dataset_size=14720, 
        zeros_and_ones=True
    )
    # Take previously unseen digits for the validation set.
    f_val_images = f_val_uncorrupted[0][config_mnist.dataset_size:]
    f_val_labels = f_val_uncorrupted[1][config_mnist.dataset_size:]
    f_val_uncorrupted = (f_val_images, f_val_labels)

    config.batch_size = 128 
    
    # Set regularization parameter for numerical stability.
    regularization = getattr(config, 'regularization', 1e-6)

    # PCPCA analysis.
    feat_dim = 784
    # In PCPCA language, y_enr is enriched observation and y_bkg is background.
    y_enr, y_enr_labels = f_train
    y_enr = y_enr.squeeze(-1).reshape(-1, feat_dim)
    y_bkg, y_bkg_labels = b_train
    y_bkg = y_bkg.squeeze(-1).reshape(-1, feat_dim)

    # Use batches to avoid running into memory issues
    enr_a_mat = jnp.repeat(jnp.eye(feat_dim, feat_dim)[None, ...], config.batch_size, axis=0).copy()
    bkg_a_mat = enr_a_mat.copy()

    metrics = {}
    # Initialize W and log_sigma using PCA of the pseudo-inverse of the
    # observation matrix.
    rng_w, rng = jax.random.split(rng, 2)
    # since enr_a_mat is an identity matrix, x_pinv = y_enr
    cov_empirical = jnp.cov(y_enr, rowvar=False)
    u_mat, s_mat, _ = jnp.linalg.svd(cov_empirical)
    weights_init = u_mat[:, :config_pcpca.latent_dim] * jnp.sqrt(s_mat[:config_pcpca.latent_dim])
    weights_init += 0.01 * jax.random.normal(
            rng_w, shape=(feat_dim, config_pcpca.latent_dim)
        )

    log_sigma_init = jnp.log(config.sigma_y)
    mu_x_init, mu_y_init = jnp.mean(y_enr, axis=0), jnp.mean(y_bkg, axis=0), 
    params = {
            'weights': jnp.asarray(weights_init), 'log_sigma': log_sigma_init,
            'mu_x': mu_x_init, 'mu_y': mu_y_init
        }
    # Optimization loop parameters.
    if config_pcpca.lr_schedule == 'linear':
        schedule = optax.schedules.linear_schedule(
            config_pcpca.learning_rate, 1e-6, config_pcpca.n_iter
        )
    elif config_pcpca.lr_schedule == 'cosine':
        schedule = optax.schedules.cosine_decay_schedule(
            init_value=config_pcpca.learning_rate, decay_steps=config_pcpca.n_iter
        )
    else:
        raise ValueError(
            f'Unknown learning rate schedule: {config_pcpca.lr_schedule}'
        )

    # Initialize Adam optimizer
    optimizer = optax.adam(learning_rate=schedule)
    opt_state = optimizer.init(params)
    
    # Run the optimization loop.
    pbar = tqdm(range(config_pcpca.n_iter))
    losses = []
    for step in pbar:
        loss_per_step = []
        pbar_step = tqdm(range(config.dataset_size // config.batch_size))
        for _ in pbar_step:
            rng_enr, rng_bkg, rng = jax.random.split(rng, 3)
            
            batch_idx = jax.random.randint(rng_enr, shape=(config.batch_size,), minval=0, maxval=config_mnist.dataset_size)
            y_enr_batch = y_enr[batch_idx] 

            batch_idx = jax.random.randint(rng_bkg, shape=(config.batch_size,), minval=0, maxval=config.dataset_size)
            y_bkg_batch = y_bkg[batch_idx] 
    
    
            grad = jax.jit(pcpca_utils.loss_grad)(
                    params, y_enr_batch, y_bkg_batch, enr_a_mat, bkg_a_mat, config_pcpca.gamma,
                    regularization
            )
            loss = jax.jit(pcpca_utils.loss)(
                    params, y_enr_batch, y_bkg_batch, enr_a_mat, bkg_a_mat, config_pcpca.gamma,
                    regularization
            )
            loss_per_step.append(loss)
            

            # Update parameters
            updates, opt_state = optimizer.update(grad, opt_state, params)
            params = optax.apply_updates(params, updates)
            params['log_sigma'] = jnp.log(config.sigma_y) # Fix log_sigma.
            pbar_step.set_postfix({'loss_per_step': f'{loss:.6f}'})

        # Log our loss.
        loss_per_step = jnp.asarray(loss_per_step).mean()
        pbar.set_postfix({'loss': f'{loss_per_step:.6f}'})
        losses.append(loss_per_step)

    print(params)
    return params, losses



In [12]:
# Train PCPCA; yields the fitted parameter dict and the per-epoch mean losses.
(params, losses) = run_pcpca()

  0%|          | 0/100 [00:00<?, ?it/s]
  0%|          | 0/256 [00:00<?, ?it/s][A
  0%|          | 0/256 [00:49<?, ?it/s, loss_per_step=nan][A
  0%|          | 1/256 [00:49<3:28:55, 49.16s/it, loss_per_step=nan][A
  0%|          | 1/256 [01:33<3:28:55, 49.16s/it, loss_per_step=nan][A
  1%|          | 2/256 [02:17<4:50:25, 68.60s/it, loss_per_step=nan][A
  0%|          | 0/100 [02:17<?, ?it/s]


KeyboardInterrupt: 