In [None]:
import sys
from pathlib import Path
# Make the repository root (two levels above the notebook's working directory)
# importable so `katacv` can be imported below.
project_root = Path.cwd().parent.parent
sys.path.append(str(project_root))

In [None]:
from katacv.utils.related_pkgs.jax_flax_optax_orbax import *
from katacv.utils.related_pkgs.utility import *
from katacv.G_VAE.parser import get_args_and_writer
# Build the config without a summary writer and without reading CLI arguments.
args = get_args_and_writer(
    no_writer=True,
    input_args=[],
    model_name='G-VAE',
    dataset_name='celeba',
)
from katacv.G_VAE.model import get_g_vae_model_state, get_decoder_state
# Full G-VAE state plus a standalone decoder state (used later to decode latents).
g_vae_state, g_vae_decoder_state = get_g_vae_model_state(args), get_decoder_state(args)
print("Successfully initialze model state!")

In [None]:
path_weights = "/home/yy/Coding/models/G-VAE/celeba/G-VAE2048-0010-lite"
weights = ocp.PyTreeCheckpointer().restore(path_weights)
# Variables of the full model. The standalone decoder state reuses the
# 'Decoder_0' sub-tree of those same variables.
restored_params = weights['params']
restored_batch_stats = weights['batch_stats']
g_vae_state = g_vae_state.replace(params=restored_params, batch_stats=restored_batch_stats)
g_vae_decoder_state = g_vae_decoder_state.replace(
    params=restored_params['Decoder_0'],
    batch_stats=restored_batch_stats['Decoder_0'],
)
print("Successfully load model weights.")

In [None]:
from katacv.utils.celeba.build_dataset import DatasetBuilder
args.path_dataset = Path("/home/yy/Coding/datasets/celeba/")
args.batch_size = 25
ds_builder = DatasetBuilder(args)
# Same pipeline options for both splits: one pass, no shuffling, no augmentation.
split_kwargs = dict(repeat=1, shuffle=False, use_aug=False)
ds_train, ds_train_size = ds_builder.get_dataset(subset='train', **split_kwargs)
ds_val, ds_val_size = ds_builder.get_dataset(subset='val', **split_kwargs)
print("Succesfully build dataset!")

In [None]:
@jax.jit
def predict(state, x):
    """Run `state.apply_fn` in inference mode (train=False) on `x`.

    Args:
        state: object exposing `apply_fn`, `params` and `batch_stats`
            (a Flax train state in this notebook).
        x: model input batch.

    Returns:
        Whatever `apply_fn` returns, traced and compiled by `jax.jit`.
    """
    return state.apply_fn(
        {'params': state.params, 'batch_stats': state.batch_stats},
        x, train=False
    )


def decoder_predict(decoder_state, z):
    """Decode `z` with `decoder_state` and min-max normalise the output to [0, 1].

    The normalisation uses the min/max of the WHOLE returned array (all images
    in the batch together), not per image.

    Args:
        decoder_state: state accepted by `predict` (the decoder sub-state here).
        z: latent batch.

    Returns:
        Host (numpy) array of the same shape as the decoder output. Returns all
        zeros when the output is constant, instead of dividing by zero (NaNs).
    """
    aug = jax.device_get(predict(decoder_state, z))
    lo, hi = aug.min(), aug.max()
    span = hi - lo
    if span == 0:
        return aug - lo
    return (aug - lo) / span

# Trigger jit compilation once for each state, using dummy inputs of the
# configured shapes (outputs are discarded).
dummy_images = jnp.empty(args.input_shape, dtype=jnp.float32)
dummy_latents = jnp.empty((args.batch_size, args.feature_size), dtype=jnp.float32)
predict(g_vae_state, dummy_images)
decoder_predict(g_vae_decoder_state, dummy_latents)
print("Complete compiled!")

In [None]:
# Get average sigma
from tqdm import tqdm
import numpy as np

# Per-class mean posterior std-dev, one value per latent dimension.
# Accumulate a per-sample sum and a sample count for each class. The previous
# running mean divided by the global batch index, which biased classes that are
# missing from some batches toward 0.
sigma_sums = np.zeros((args.class_num, args.feature_size))
class_counts = np.zeros(args.class_num, dtype=np.int64)
bar = tqdm(ds_train, total=ds_train_size)
for x, y in bar:
    x, y = x.numpy(), y.numpy()
    mu, logsigma2 = jax.device_get(predict(g_vae_state, x)[0])
    sigma = np.exp(0.5 * logsigma2)  # == sqrt(exp(logsigma2)), without overflowing first
    for j in range(args.class_num):
        mask = y == j
        if not mask.any():
            continue
        sigma_sums[j] += sigma[mask].sum(axis=0)
        class_counts[j] += mask.sum()
# Classes never seen keep an all-zero row (count clamped to 1 to avoid 0/0).
sigmas = sigma_sums / np.maximum(class_counts, 1)[:, None]

In [None]:
from katacv.utils.celeba.label2readable import label2readable
# One line per class: its readable name and the mean sigma over all latent dims.
for label_id in label2readable:
    print(label2readable[label_id], sigmas[label_id].mean())

In [None]:
import matplotlib.pyplot as plt
# Scatter each class's per-dimension mean sigma against the latent dimension index.
# The x range comes from `sigmas` itself. The old `range(n)` used `n`, which is
# only defined in a later cell.
xs = range(sigmas.shape[1])
for key, value in label2readable.items():
    plt.scatter(xs, sigmas[key], label=value)
plt.xlabel("latent dimension")
plt.ylabel("mean sigma")
plt.legend()

In [None]:
# Keep, for each class, only the latent dims whose mean sigma reaches the
# (1 - threshold_rate) position of that class's sorted sigmas. All other
# entries of `deltas` become 0.
threshold_rate = 0.05
n_dims = sigmas.shape[1]
# Clamp so that threshold_rate=0 still selects the largest value instead of an
# empty slice, which would fail to broadcast in np.where.
threshold_idx = min(int(n_dims * (1 - threshold_rate)), n_dims - 1)
threshold = np.sort(sigmas, axis=-1)[:, threshold_idx:threshold_idx+1]  # (class_num, 1)
deltas = np.where(sigmas >= threshold, sigmas, 0)

# Define xs here so this cell does not depend on a variable from another cell.
xs = range(n_dims)
for key, value in label2readable.items():
    plt.scatter(xs, deltas[key], label=value)
plt.legend()

In [None]:
# Select the `target_idx`-th image of the validation set, counted across batches.
# Use each batch's real length: the old hard-coded 10 did not match
# args.batch_size (25), so it picked the wrong image.
target_idx = 16
target_image = None
remaining = target_idx
for x, y in ds_val:
    x = x.numpy()
    if remaining < len(x):
        target_image = x[remaining]
        break
    remaining -= len(x)
if target_image is None:
    raise IndexError(f"target_idx={target_idx} is beyond the validation set")
# predict(...)[0] is (mu, logsigma2); keep the mean latent as z, shape (1, ...).
z, _ = jax.device_get(predict(g_vae_state, target_image[None, ...])[0])
print(z.shape)

In [None]:
# Side by side: the selected input image (left) and the decoder output for `z` (right).
aug = decoder_predict(g_vae_decoder_state, z)
panels = {121: target_image, 122: aug[0]}
for subplot_code, panel_image in panels.items():
    plt.subplot(subplot_code)
    plt.imshow(panel_image)

In [None]:
# Display the class-index -> readable-label mapping used as legend labels in this notebook.
label2readable

In [None]:
# Shape of the selected validation image (as fed to the model without the batch axis).
print(np.shape(target_image))

In [None]:
# Latent traversal grid of (2r+1) x (2c+1) decodes around `z` (centre tile = z).
# Horizontal: columns left of centre add deltas[2], columns right subtract deltas[0].
# Vertical: rows above centre subtract deltas[1], rows below add deltas[3].
# The shift grows linearly with the distance from the centre tile.
r, c = 3, 4
alpha_x, alpha_y = 3 / c, 3 / r
grid_rows = []
for row_idx in range(2 * r + 1):
    if row_idx < r:
        step_y = -deltas[1]
    elif row_idx > r:
        step_y = deltas[3]
    else:
        step_y = 0
    dist_y = abs(row_idx - r)
    tiles = []
    for col_idx in range(2 * c + 1):
        if col_idx > c:
            step_x = -deltas[0]
        elif col_idx < c:
            step_x = deltas[2]
        else:
            step_x = 0
        dist_x = abs(col_idx - c)
        shift = alpha_x * dist_x * step_x + alpha_y * dist_y * step_y
        tiles.append(decoder_predict(g_vae_decoder_state, z + shift)[0])
    grid_rows.append(np.concatenate(tiles, axis=1))
image = np.concatenate(grid_rows, axis=0)
image = (image*255).astype('uint8')  # decoder_predict output is in [0, 1]
plt.figure(figsize=(10, 15))
plt.imshow(image)
from PIL import Image
Image.fromarray(image).save(str(args.path_logs.joinpath("change_image.jpg")))

In [None]:
total = 10000  # NOTE(review): not used anywhere in the visible notebook
# Group the posterior means of all training samples by class label.
# Boolean masks replace the old per-sample Python append loop and keep the same
# ordering (batch order, then position within the batch).
zs = [[] for _ in range(args.class_num)]
bar = tqdm(ds_train, total=ds_train_size)
for x, y in bar:
    x, y = x.numpy(), y.numpy()
    mu, _ = jax.device_get(predict(g_vae_state, x)[0])
    for cls in range(args.class_num):
        mask = y == cls
        if mask.any():
            zs[cls].append(mu[mask])
for cls in range(args.class_num):
    print(label2readable[cls], sum(len(chunk) for chunk in zs[cls]))
    # np.concatenate([]) raises, so a class with no samples becomes an empty (0, F) array.
    zs[cls] = (np.concatenate(zs[cls], axis=0) if zs[cls]
               else np.empty((0, args.feature_size)))

In [None]:
n = 1000  # max number of samples plotted per class
colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red']

# Fit ONE 2-D PCA basis on the pooled latent means of all classes, so every class
# is projected onto the same axes. Previously each class used its own basis, which
# made the overlaid scatters incomparable. The old code also took eigenvector ROWS
# (`evec[:2]`) from an unsorted `np.linalg.eig`.
z_pooled = np.concatenate(zs, axis=0)
z_center = z_pooled.mean(axis=0)
z_dev = z_pooled - z_center
scatter_mat = z_dev.T @ z_dev
# The matrix is symmetric, so eigh returns real eigenpairs. Eigenvectors are the
# COLUMNS of `eigvecs`; sort eigenvalues descending and keep the top two.
eigvals, eigvecs = np.linalg.eigh(scatter_mat)
pc_basis = eigvecs[:, np.argsort(eigvals)[::-1][:2]]  # (latent_dim, 2)

fig, axs = plt.subplots(1, 3, figsize=(12, 4))
# Loop variable is `z_cls` so the global `z` (target latent) is not overwritten.
for cls, z_cls in enumerate(zs):
    z_down = (z_cls[:n] - z_center) @ pc_basis
    style = dict(label=label2readable[cls], c=colors[cls])
    # Panel 0: classes 0 and 1; panel 1: the other classes; panel 2: all classes.
    axs[0 if cls in (0, 1) else 1].scatter(z_down[:, 0], z_down[:, 1], **style)
    axs[2].scatter(z_down[:, 0], z_down[:, 1], **style)
for ax in axs:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.legend(loc='lower left')
plt.tight_layout()
plt.savefig(str(args.path_logs.joinpath("pca.jpg")), dpi=300)

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
def show_image_aug(x, n=14, name='image_aug', threshold_rate=0.05, step=0.5):
    """Save and display latent-space augmentations for every image in `x`.

    Each image gets one row, laid out left to right as:
      decodes of mu - k*step*delta for k = n//2-1 ... 0,
      the input image (colour-inverted if single-channel),
      decodes of mu + k*step*delta for k = 0 ... n//2-1,
      a 5-pixel blank strip,
      one decode of mu + standard-normal noise (RandomState seeded with 42).
    `delta` keeps each sample's latent std-devs at or above its
    (1 - threshold_rate) sorted position and zeroes the rest. Rows are stacked
    vertically into a single image.

    Args:
        x: batch of input images. Rank 4 with channels last, presumably values in
            [0, 1] (TODO confirm against the dataset builder).
        n: number of shifted decodes. n//2 are made per direction; k = 0 (the plain
            reconstruction) appears on both sides.
        name: output file stem; the image is written to `args.path_figures/<name>.jpg`.
        threshold_rate: fraction of latent dims treated as high-variance, per sample.
        step: latent step per k, in units of delta (default 0.5 as before).

    Side effects: draws two histograms on the current matplotlib axes, prints
    mean(mu) and mean(sigma), writes the JPEG and opens it with PIL's viewer.
    """
    distrib, _, _ = jax.device_get(predict(g_vae_state, x))
    mu, logsigma2 = distrib
    sigma = np.sqrt(np.exp(logsigma2))
    # Clamp so threshold_rate=0 still selects the largest sigma instead of an empty slice.
    n_dims = sigma.shape[1]
    threshold_idx = min(int(n_dims * (1 - threshold_rate)), n_dims - 1)
    threshold = np.sort(sigma, axis=-1)[:, threshold_idx:threshold_idx+1]
    delta = np.where(sigma >= threshold, sigma, 0)
    plt.hist(sigma[0], bins=50)
    plt.hist(delta[0][delta[0] > 0], bins=50)
    print(mu.mean(), sigma.mean())
    pos, neg = [], []
    for i in range(n // 2):
        neg.append(decoder_predict(g_vae_decoder_state, mu - i * delta * step))
        pos.append(decoder_predict(g_vae_decoder_state, mu + i * delta * step))
    image = x
    if image.shape[-1] == 1:  # gray origin image: invert colors
        image = 1 - image
    # Each later (larger-k) negative decode is prepended, so k grows toward the left edge.
    for aug in neg:
        image = np.concatenate([aug, image], axis=2)
    for aug in pos:
        image = np.concatenate([image, aug], axis=2)
    # Blank separator strip, 5 pixels wide. Its channel count must match the image;
    # the old hard-coded 3 broke the concatenation for single-channel inputs.
    blank = np.zeros((image.shape[0], image.shape[1], 5, image.shape[-1]))
    image = np.concatenate([image, blank], axis=2)
    # Gaussian augmentation: mu + N(0, I). A private RandomState(42) produces the same
    # draws as np.random.seed(42) + np.random.randn, without resetting the global RNG.
    noise = np.random.RandomState(42).randn(*mu.shape)
    aug = decoder_predict(g_vae_decoder_state, mu + noise)
    image = np.concatenate([image, aug], axis=2)
    # Stack the per-sample rows vertically: (B, H, W', C) -> (B*H, W', C).
    image = image.reshape((-1, *image.shape[-2:]))

    if image.shape[-1] == 1:
        image = image[..., 0]
    image = (image*255).astype('uint8')
    image = Image.fromarray(image)
    image.save(str(args.path_figures.joinpath(name+'.jpg')))
    image.show()

args.path_figures = args.path_logs.joinpath("figures")
# parents=True: also create path_logs if it does not exist yet.
args.path_figures.mkdir(parents=True, exist_ok=True)
# Take the first training batch directly. This avoids re-iterating the tqdm `bar`
# left over from an earlier cell (hidden state on a fresh kernel).
x, y = next(iter(ds_train))
x, y = x.numpy(), y.numpy()
show_image_aug(x[:10])