In [None]:
import sys
from pathlib import Path

# Append the directory two levels above the notebook's working directory to
# the import path (presumably the repo root that contains `katacv`).
sys.path.append(f"{Path.cwd().parent.parent}")

In [None]:
# Project imports. The two star imports presumably supply names such as
# `jax`, `jnp` and `ocp` (orbax) that later cells use without an explicit
# import — TODO confirm and replace with explicit imports.
from katacv.utils.related_pkgs.jax_flax_optax_orbax import *
from katacv.utils.related_pkgs.utility import *
from katacv.G_VAE.parser import get_args_and_writer
# Build the run configuration for the G-VAE / CelebA setup.
# `input_args=[]` presumably keeps the parser from reading the notebook
# kernel's own sys.argv, and `no_writer=True` presumably skips creating a
# log writer — verify against katacv.G_VAE.parser.
args = get_args_and_writer(no_writer=True, input_args=[], model_name='G-VAE', dataset_name='celeba')
# NOTE(review): this import deliberately follows `args` creation; keep the
# order unless importing the model module is confirmed not to depend on it.
from katacv.G_VAE.model import get_g_vae_model_state, get_decoder_state
# Full G-VAE state and a standalone decoder state. Their params/batch_stats
# are replaced with checkpoint weights in a later cell.
g_vae_state = get_g_vae_model_state(args)
g_vae_decoder_state = get_decoder_state(args)
print("Successfully initialze model state!")

In [None]:
# Restore the checkpoint pytree. Its `params` / `batch_stats` go into the full
# G-VAE state; their `Decoder_0` sub-trees go into the standalone decoder state.
# NOTE(review): hard-coded absolute local checkpoint path.
path_weights = "/home/wty/Coding/models/G-VAE/celeba/G-VAE2048-0010-lite"
weights = ocp.PyTreeCheckpointer().restore(path_weights)
g_vae_state = g_vae_state.replace(
    **{key: weights[key] for key in ('params', 'batch_stats')}
)
g_vae_decoder_state = g_vae_decoder_state.replace(
    **{key: weights[key]['Decoder_0'] for key in ('params', 'batch_stats')}
)
print("Successfully load model weights.")

In [None]:
from katacv.utils.celeba.build_dataset import DatasetBuilder
# NOTE(review): hard-coded absolute local dataset path.
args.path_dataset = Path("/home/wty/Coding/datasets/celeba/")
ds_builer = DatasetBuilder(args)
# Build the train and val splits with repeat=1, shuffle=False, use_aug=False;
# each call yields a (dataset, size) pair.
(ds_train, ds_train_size), (ds_val, ds_val_size) = (
    ds_builer.get_dataset(subset=split, repeat=1, shuffle=False, use_aug=False)
    for split in ('train', 'val')
)
print("Succesfully build dataset!")

In [None]:
@jax.jit
def predict(state, x):
    """Run ``state.apply_fn`` in inference mode (``train=False``) under ``jax.jit``.

    Args:
        state: object exposing ``apply_fn``, ``params`` and ``batch_stats``
            (the G-VAE or decoder state in this notebook).
        x: model input.

    Returns:
        Whatever ``state.apply_fn`` returns for these variables and input.
    """
    variables = {'params': state.params, 'batch_stats': state.batch_stats}
    return state.apply_fn(variables, x, train=False)

def decoder_predict(decoder_state, z):
    """Decode latents ``z`` and min-max rescale the result to [0, 1] on the host.

    The rescaling uses one min/max over the whole output batch, not per image.
    If the output is constant, an all-zero array is returned instead of
    dividing by zero (which would yield NaNs).

    Args:
        decoder_state: state passed to ``predict``.
        z: latent batch.

    Returns:
        Host (numpy) array with the decoder output rescaled to [0, 1].
    """
    aug = jax.device_get(predict(decoder_state, z))
    lo = aug.min()
    span = aug.max() - lo
    if span == 0:
        # Constant output: every element equals `lo`, so this is all zeros.
        return aug - lo
    return (aug - lo) / span

# Warm-up: call `predict` once for the full model and once (via
# `decoder_predict`) for the decoder, so jit tracing/compilation for these
# dummy input shapes happens in this cell.
predict(
    g_vae_state,
    jnp.empty(args.input_shape, dtype=jnp.float32),
)
decoder_predict(
    g_vae_decoder_state,
    jnp.empty((args.batch_size, args.feature_size), dtype=jnp.float32),
)
print("Complete compiled!")

In [None]:
# Get average sigma: per-class mean of the encoder's sigma vector over the
# first SIGMA_MAX_BATCHES batches of ds_train.
from tqdm import tqdm
import numpy as np

# Number of training batches to average over; set to None for a full pass
# over ds_train (previously a hidden `break` limited this to one batch).
SIGMA_MAX_BATCHES = 1

# Accumulate exact per-class sums and sample counts. The old running-mean
# update reused the class loop variable `i` as the batch count, scaling
# class c by 1/(c+1).
sigma_sums = np.zeros((args.class_num, args.feature_size))
sigma_counts = np.zeros(args.class_num, dtype=np.int64)
bar = tqdm(ds_train, total=ds_train_size)
for batch_idx, (x, y) in enumerate(bar):
    x, y = x.numpy(), y.numpy()
    distrib, _, _ = jax.device_get(predict(g_vae_state, x))
    # The same tuple is unpacked as (mu, logsigma2) in show_image_aug, so the
    # second element is log(sigma^2); convert it to sigma before averaging.
    mu, logsigma2 = distrib
    sigma = np.sqrt(np.exp(logsigma2))
    for c in range(args.class_num):
        in_class = y == c  # assumes y is a 1-D array of class indices
        sigma_sums[c] += sigma[in_class].sum(0)  # empty class -> adds zeros
        sigma_counts[c] += in_class.sum()
    if SIGMA_MAX_BATCHES is not None and batch_idx + 1 >= SIGMA_MAX_BATCHES:
        break

# Classes never seen get NaN rather than a misleading zero vector.
sigmas = [
    sigma_sums[c] / sigma_counts[c] if sigma_counts[c] else np.full(args.feature_size, np.nan)
    for c in range(args.class_num)
]

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
def show_image_aug(x, n=14, name='image_aug', threshold_rate=0.05):
    """Visualise latent-space augmentations of an image batch and save the grid.

    The encoder's ``mu`` and ``logsigma2`` are obtained via
    ``predict(g_vae_state, x)``. Per image, the latent dims whose sigma is in
    the top ``threshold_rate`` fraction form ``delta`` (other dims are 0).
    ``mu -/+ i * 0.5 * delta`` for ``i = 0 .. n//2 - 1`` is decoded; ``i = 0``
    is ``mu`` itself. Each output row is laid out along the width as::

        neg[n//2-1] ... neg[0] | x | pos[0] ... pos[n//2-1] | 5px blank | mu + N(0, I)

    Rows for the batch are stacked vertically, saved to
    ``args.path_figures / (name + '.jpg')`` and shown. A histogram of the
    first image's sigma (all dims vs. perturbed dims) is also plotted.

    Args:
        x: image batch; assumed (B, H, W, C) with values in [0, 1] — TODO
            confirm against the dataset pipeline.
        n: total number of perturbation steps (n // 2 on each side).
        name: output file stem.
        threshold_rate: fraction of latent dims (largest sigma) to perturb;
            at least one dim is always selected.
    """
    distrib, _, _ = jax.device_get(predict(g_vae_state, x))
    mu, logsigma2 = distrib
    sigma = np.sqrt(np.exp(logsigma2))

    # Clamp so threshold_rate == 0 (or out-of-range values) still picks a valid
    # column; previously the slice could be empty and fail to broadcast.
    latent_dim = sigma.shape[1]
    threshold_idx = min(max(int(latent_dim * (1 - threshold_rate)), 0), latent_dim - 1)
    threshold = np.sort(sigma, axis=-1)[:, threshold_idx:threshold_idx + 1]
    delta = np.where(sigma >= threshold, sigma, 0)

    # Diagnostic histogram on a fresh figure (not whatever figure is current).
    fig, ax = plt.subplots()
    ax.hist(sigma[0], bins=50, label='all dims')
    ax.hist(delta[0][delta[0] > 0], bins=50, label='perturbed dims')
    ax.set(title=f'{name}: latent sigma (first image)', xlabel='sigma', ylabel='count')
    ax.legend()
    print(mu.mean(), sigma.mean())

    steps = range(n // 2)
    neg = [decoder_predict(g_vae_decoder_state, mu - i * delta * 0.5) for i in steps]
    pos = [decoder_predict(g_vae_decoder_state, mu + i * delta * 0.5) for i in steps]

    image = x
    if image.shape[-1] == 1:  # gray origin image invert colors, mid: (B,N,N,1)
        image = 1 - image
    image = np.concatenate([*reversed(neg), image, *pos], axis=2)
    # Blank separator matches the channel count so gray images also work
    # (it was hard-coded to 3 channels).
    blank = np.zeros((image.shape[0], image.shape[1], 5, image.shape[-1]))
    image = np.concatenate([image, blank], axis=2)
    # Gaussian augmentation. A local RandomState(42) draws the same numbers as
    # np.random.seed(42); np.random.randn(...) without mutating global RNG state.
    noise = np.random.RandomState(42).randn(*mu.shape)
    image = np.concatenate([image, decoder_predict(g_vae_decoder_state, mu + noise)], axis=2)
    # Stack the batch vertically: (B, H, W, C) -> (B*H, W, C).
    image = image.reshape((-1, *image.shape[-2:]))

    if image.shape[-1] == 1:
        image = image[..., 0]
    # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
    image = (np.clip(image, 0, 1) * 255).astype('uint8')
    image = Image.fromarray(image)
    image.save(str(args.path_figures.joinpath(name + '.jpg')))
    image.show()

args.path_figures = args.path_logs.joinpath("figures")
# parents=True: mkdir would otherwise raise FileNotFoundError when the log
# directory itself does not exist yet.
args.path_figures.mkdir(parents=True, exist_ok=True)
# Visualise augmentations for the first N_SHOW images of the first training
# batch (ds_train is built with shuffle=False, so this batch is fixed).
N_SHOW = 10
x, y = (t.numpy() for t in next(iter(ds_train)))
show_image_aug(x[:N_SHOW])