In [None]:
from diffusers import AutoencoderKL
import torch
import matplotlib.pyplot as plt
import random
from pathlib import Path
from PIL import Image
import torchvision.transforms.functional as TF

In [None]:
# Hugging Face Hub repo that provides the Stable Diffusion v1.5 VAE. The original
# "runwayml/stable-diffusion-v1-5" repo has been removed from the Hub; this is the
# maintained mirror of the same weights.
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"

In [None]:
# Load only the VAE sub-model of the pipeline.
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
# Move it to the first GPU, then cast its weights to float16.
vae = vae.to("cuda:0").half()
# Latent scaling constant stored in the VAE's config.
scaling_factor = vae.config.scaling_factor

In [None]:
# Pre-computed latents; map_location places every stored tensor directly on the GPU.
# NOTE(review): torch.load unpickles the file — only load files you produced yourself.
emb_path = Path("out/latents_01.pt")
embs = torch.load(f=emb_path, map_location="cuda:0")

In [None]:
images_synth = []

# Decode every stored latent back to pixel space with the VAE.
# Each entry of `embs` is unpacked as a (latent, cls) pair.
with torch.inference_mode():
    for i, (emb_synth, cls) in enumerate(embs):
        if i % 20 == 0:
            print(i)  # coarse progress indicator
        # Undo the latent scaling (presumably applied as `* scaling_factor` when the
        # latents were produced — confirm), then cast to fp16 to match the VAE weights.
        latents_synth = ((1 / scaling_factor) * emb_synth.to("cuda:0")).half()
        image_synth = vae.decode(latents_synth).sample
        # Store decoded images on the CPU. Accumulating them on the GPU makes device
        # memory grow with len(embs). It would also make the list written by
        # torch.save loadable only on a CUDA machine (unless map_location is given).
        images_synth.append((image_synth.cpu(), cls))

In [None]:
# Persist the list of decoded (image, cls) pairs alongside the latents in out/.
torch.save(obj=images_synth, f=Path("out/images_01.pt"))

In [None]:
def visualize_images(images, figsize=(16,16), rescale=True):
    """Show a sequence of image tensors in a near-square grid of subplots.

    Args:
        images: Sequence of channels-first image tensors (C, H, W). Each one is
            converted with ``.float().cpu().numpy()`` and transposed to
            (H, W, C) for ``imshow``. Single-channel images are drawn in grayscale.
        figsize: Matplotlib figure size in inches.
        rescale: If True (the default), map values from [-1, 1] to [0, 1] and clip
            before plotting. The real images in this notebook are built as
            ``to_tensor(...) * 2 - 1``, and the VAE decoder output is presumably in
            the same range. Without rescaling, ``imshow`` clips float RGB data to
            [0, 1] and the negative half of the range is lost. Pass False for
            images already in [0, 1].

    Does nothing when ``images`` is empty.
    """
    num_images = len(images)
    if num_images == 0:
        return

    num_rows = int(num_images ** 0.5)
    # Ceiling division so every image gets an axis. Flooring here silently dropped
    # images (e.g. 50 images -> 7 x 7 = 49 axes).
    num_cols = -(-num_images // num_rows)

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so `.flat` works.
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize,
                           tight_layout=True, squeeze=False)

    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= num_images:
            continue  # spare axis at the end of the last row
        image = images[i].float().cpu().numpy().transpose(1, 2, 0)
        if rescale:
            image = (image * 0.5 + 0.5).clip(0, 1)
        if image.shape[2] == 1:
            # imshow rejects (H, W, 1); draw single-channel images as 2-D grayscale.
            ax.imshow(image[:, :, 0], cmap='gray', vmin=0, vmax=1)
        else:
            ax.imshow(image)

    plt.show()

In [None]:
# Take the image tensor of each (decoded image, cls) pair. [0] selects index 0 along
# the leading dim (assumed to be a batch of one — TODO confirm) and squeeze() drops
# any remaining singleton dims.
img_tensors_synth = [decoded[0].squeeze() for decoded, _ in images_synth]
print(img_tensors_synth[0].shape)
# random.sample raises ValueError when asked for more items than exist.
n_show_synth = min(50, len(img_tensors_synth))
visualize_images(random.sample(img_tensors_synth, n_show_synth), figsize=(16,16))

In [None]:
def load_image_tensor(path):
    """Open an image file, force 3-channel RGB and return it as a float tensor in [0, 1]."""
    with Image.open(path) as img:
        # ImageNet contains some grayscale / CMYK JPEGs, which would otherwise give
        # 1- or 4-channel tensors. torchvision's default ImageFolder loader also
        # converts to RGB.
        return TF.to_tensor(img.convert("RGB"))


data_path = Path("ImageNet10/train/n01440764")
# Sorted so the loading order does not depend on the filesystem's directory order.
image_paths = sorted(data_path.glob("*.JPEG"))
# Map [0, 1] -> [-1, 1], presumably to match the VAE's pixel range.
real_images = [load_image_tensor(img_path) * 2. - 1. for img_path in image_paths]
# random.sample raises ValueError when asked for more items than exist.
n_show_real = min(50, len(real_images))
visualize_images(random.sample(real_images, n_show_real), figsize=(16,16))