In [1]:
import multiprocessing as mp
import time
import json
import wandb
import os
import numpy as np
import io
import tempfile
import gc
import logging
import psutil
import cv2
import threading
import traceback
import jax.numpy as jnp
import numpy as np
import jax
import argparse
from diffusers import FlaxAutoencoderKL
from diffusers.models.vae_flax import FlaxAutoencoderKLOutput

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# dtype handed to from_pretrained below.
weight_dtype = jnp.float16
# Load the Flax VAE module and its parameters from the "vae" subfolder of the
# "flax" revision of the CompVis/stable-diffusion-v1-4 repository.
# NOTE(review): the diffusers docs describe `dtype` as the computation dtype
# only, so vae_params may still be stored in float32 — confirm if memory matters.
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="flax",
    subfolder="vae",
    dtype=weight_dtype,
)

  deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)


In [10]:
def encode(img, rng, vae_params):
    """Run the VAE encoder and return the latent distribution's statistics.

    Args:
        img: Image batch passed unchanged to ``vae.encode``.
            NOTE(review): presumably channels-first floats in [-1, 1] as
            produced by prep_batch — confirm.
        rng: PRNG key. Currently unused; it would only be needed to sample
            from the latent distribution, which this function does not do.
        vae_params: Parameters applied to the module-level ``vae``.

    Returns:
        A ``(mean, std)`` tuple taken from the encoder's ``latent_dist``.
    """
    posterior = vae.apply({"params": vae_params}, img, method=vae.encode).latent_dist
    # No sampling, transposing or scaling by vae.config.scaling_factor happens
    # here; the caller receives the raw distribution parameters.
    return posterior.mean, posterior.std

def create_key(seed=0):
    """Build a JAX PRNG key from an integer ``seed`` (defaults to 0)."""
    key = jax.random.PRNGKey(seed)
    return key

# Fixed seed-0 key; encode() accepts it but does not currently use it.
rng = create_key(0)
# Parallel version of encode: the first argument (the image batch) is mapped
# over its leading axis (in_axes 0), while rng and vae_params are broadcast
# unchanged to every replica (in_axes None).
model = jax.pmap(encode, in_axes=(0, None, None))

In [22]:
# Default (D, B, C, H, W) shape a prepared batch must have.
expected_input_shape = (4, 8, 3, 256, 256)


def prep_batch(batch, expected_shape=None):
    """Turn a channels-last uint8 image batch into channels-first float32 in [-1, 1].

    Args:
        batch: uint8 array laid out as (D, B, H, W, C) with values 0-255.
        expected_shape: Shape the batch must have once rearranged to
            (D, B, C, H, W). Defaults to the module-level
            ``expected_input_shape``.

    Returns:
        A new float32 array of shape ``expected_shape`` with values in
        [-1, 1]. The input array is not modified.

    Raises:
        AssertionError: If the input dtype is not uint8, the rearranged shape
            does not match ``expected_shape``, or the scaled values fall
            outside [-1, 1].
    """
    if expected_shape is None:
        expected_shape = expected_input_shape
    expected_shape = tuple(expected_shape)

    # (D, B, H, W, C) -> (D, B, C, H, W)
    batch = np.transpose(batch, (0, 1, 4, 2, 3))
    # A uint8 array can only hold 0-255, so the dtype check alone covers the
    # input value range; no separate min/max scan is needed.
    assert batch.dtype == np.uint8, f"{batch.dtype} != uint8"
    assert batch.shape == expected_shape, f"{batch.shape} != {expected_shape}"

    # astype always copies, so the in-place arithmetic below never touches the
    # caller's array and avoids allocating a fresh temporary per step.
    batch = batch.astype(np.float32)
    batch /= 255.0  # (D, B, C, H, W) (float32) (0-1)
    # added 7/24/2023: rescale from [0, 1] to [-1, 1]
    batch *= 2.0
    batch -= 1.0

    assert batch.dtype == np.float32, f"{batch.dtype} != float32"
    assert batch.max() <= 1.0, f"{batch.max()} > 1.0"
    assert batch.min() >= -1.0, f"{batch.min()} < -1.0"
    gc.collect()
    return batch

In [19]:
# Stand-in for a batch of RGB frames: random uint8 pixels laid out as
# (D, B, H, W, C). randint's upper bound is exclusive, so values span 0-254.
img = np.random.randint(low=0, high=255, size=(4, 8, 256, 256, 3), dtype=np.uint8)
print(img.shape, img.dtype, sep="\n")
print(img.min(), img.max())

(4, 8, 256, 256, 3)
uint8
0 254


In [23]:
# Normalise the raw batch and inspect the result.
# NOTE: this rebinds img to the float32 output, so re-running the cell feeds
# float32 data back into prep_batch and trips its uint8 assertion.
img = prep_batch(img)
print(img.shape, img.dtype, sep="\n")
print(img.min(), img.max())

(4, 8, 3, 256, 256)
float32
-1.0 0.99215686


In [29]:
# Run the pmapped encoder, then wait on both outputs (mean, std) so the
# computation has actually finished before the cell returns.
x = tuple(part.block_until_ready() for part in model(img, rng, vae_params))

In [31]:
# Shape, dtype and value range for each encoder output (mean first, then std).
for z in x:
    print(z.shape, z.dtype, sep="\n")
    print(z.min(), z.max())

(4, 8, 32, 32, 4)
float16
-15.3 12.63
(4, 8, 32, 32, 4)
float16
1.71e-05 0.01269


In [33]:
# Dump both encoder outputs, one item per line, with a banner between them.
# NOTE(review): this prints the full arrays; a slice or summary would keep the
# notebook output short.
print(
    "Mean:",
    x[0],
    "##########Separation##########",
    "Std:",
    x[1],
    sep="\n",
)

Mean:
[[[[[-2.4258e+00  1.0137e+00  2.0566e+00 -2.6602e+00]
    [-7.6211e+00  8.8135e-01 -5.5615e-01 -5.3438e+00]
    [-4.0312e+00  7.5703e+00  8.3252e-01 -2.5000e+00]
    ...
    [-4.4023e+00  1.3779e+00 -1.5586e+00 -3.6992e+00]
    [-9.0859e+00 -5.5664e-01 -1.1406e+00 -8.1094e+00]
    [-3.0039e+00  1.2217e+00 -2.5918e+00 -4.0273e+00]]

   [[-7.5586e-01  4.1484e+00 -6.3242e+00 -1.2217e+00]
    [-3.1797e+00  1.2158e+00 -2.8789e+00 -2.9180e+00]
    [-6.4297e+00  1.2764e+00 -3.8652e+00 -5.1406e+00]
    ...
    [-3.1719e+00  2.7051e+00 -2.6406e+00 -1.4893e+00]
    [-4.6758e+00  2.2207e+00 -3.0605e+00 -4.3555e+00]
    [-1.2939e+00  2.6211e+00 -6.3984e+00 -9.8145e-01]]

   [[-5.9023e+00 -3.8452e-01  6.8994e-01 -4.6328e+00]
    [-3.2617e-01  4.8047e+00 -1.2927e-01  4.6478e-02]
    [-2.6719e+00  3.4277e+00 -4.5508e+00 -5.6006e-01]
    ...
    [-7.3945e+00  4.2930e+00 -6.2500e-01 -5.6953e+00]
    [-3.2363e+00  2.6953e+00  2.6094e+00 -9.4922e-01]
    [-6.0820e+00  2.3848e+00  5.6836e+00 -3.0820