In [None]:
# Run from the parent directory so `utils`, `models` and the relative
# "ckpts/" and "data/" paths used below resolve.
# NOTE: not idempotent -- re-running this cell moves up one more directory.
%cd ..
# Expose only physical GPU 1; inside this process it is addressed as "cuda:0".
# Must run before torch initializes CUDA.
%env CUDA_VISIBLE_DEVICES=1

In [None]:
import torch
import random
from utils import load_encoders
from models.sit import SiT_models
from diffusers import AutoencoderKL
import json
import numpy as np
import h5py
import PIL
import io
from torchvision.transforms import Normalize
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import matplotlib.pyplot as plt

# Optional fast PNG decoder used by load_h5_file. Left as None here, so PNG
# entries are always decoded with PIL.
pyspng = None

In [None]:
# Per-channel RGB normalization statistics for CLIP-family encoders
# (consumed by preprocess_raw_image).
CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD = (
    (0.48145466, 0.4578275, 0.40821073),
    (0.26862954, 0.26130258, 0.27577711),
)

### Load encoders
(encoders, encoder_types, architectures) = load_encoders(
    "dinov2-vit-b",
    "cuda:0",
)

### Model HPs
# The checkpoint entry and the VAE variant are chosen independently: a single
# switch for both would make "load EMA weights" silently swap the VAE too.
weights_key = "model"  # checkpoint entry to load: "model" or "ema"
vae_type = "mse"       # VAE variant: loads stabilityai/sd-vae-ft-{vae_type}
model_name = "SiT-B/2"
# ckpt_path = "ckpts/sit-b-base-400k-last.pt"  # <-- Change the checkpoint file here
ckpt_path = "ckpts/sit-b-linear-dinov2-b-enc8-400k-last.pt"
resolution = 256
num_classes = 1000
assert resolution % 8 == 0, "resolution must be divisible by 8 (latent grid is resolution // 8)"
latent_size = resolution // 8
z_dims = [encoder.embed_dim for encoder in encoders]
encoder_depth = 8  # NOTE(review): must match the checkpoint (its filename says "enc8")
block_kwargs = {"fused_attn": False, "qk_norm": False}

model = SiT_models[model_name](
    latent_size=latent_size,
    num_classes=num_classes,
    use_cfg=True,
    z_dims=z_dims,
    encoder_depth=encoder_depth,
    **block_kwargs,
).to("cuda:0")
state_dict = torch.load(ckpt_path, map_location="cuda:0")


### Load weights
# Fail loudly instead of silently leaving the model randomly initialized.
if weights_key not in ("model", "ema"):
    raise ValueError('Unknown weights_key: {} (expected "model" or "ema")'.format(weights_key))
model.load_state_dict(state_dict[weights_key])
model.eval()

### Load VAE
vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{vae_type}").to("cuda:0")

In [None]:
### Prepare the data
# JSON indexes describing the entries stored in each HDF5 archive.
with open("data/images_h5.json", "r") as f:
    images_h5_cfg = json.loads(f.read())

with open("data/vae-sd_h5.json", "r") as f:
    vae_h5_cfg = json.loads(f.read())

def load_h5_file(hf, path):
    """Read one entry from an HDF5 archive and decode it according to its extension.

    Args:
        hf: An open ``h5py.File`` (any mapping works) where ``np.array(hf[path])``
            yields the stored payload; for ``.png``/``.json`` entries that payload
            is the raw encoded bytes.
        path: Key of the entry; its extension selects the decoder.

    Returns:
        ``.png``  -> decoded image as a channel-first (C, H, W) numpy array
                     (2-D / grayscale decodes get C == 1).
        ``.json`` -> the parsed JSON value.
        ``.npy``  -> ``np.array(hf[path])``.

    Raises:
        ValueError: if the extension is not one of the above.
    """
    if path.endswith('.png'):
        encoded = np.array(hf[path]).tobytes()
        if pyspng is not None:
            # pyspng.load takes the encoded PNG as ``bytes``, not a file object.
            rtn = pyspng.load(encoded)
        else:
            # Import the submodule explicitly: ``import PIL`` alone does not
            # guarantee that ``PIL.Image`` is loaded.
            from PIL import Image
            with Image.open(io.BytesIO(encoded)) as img:
                rtn = np.array(img)
        # Give 2-D decodes a channel axis, then go HWC -> CHW.
        rtn = rtn.reshape(*rtn.shape[:2], -1).transpose(2, 0, 1)
    elif path.endswith('.json'):
        rtn = json.loads(np.array(hf[path]).tobytes().decode('utf-8'))
    elif path.endswith('.npy'):
        rtn = np.array(hf[path])
    else:
        raise ValueError('Unknown file type: {}'.format(path))
    return rtn

def preprocess_raw_image(x, enc_type):
    """Scale and normalize a batch of raw images for a given encoder family.

    Args:
        x: Image tensor with pixel values in [0, 255]; assumed (B, 3, H, W)
            since Normalize/interpolate are applied to the whole batch.
        enc_type: Encoder identifier, matched by substring against
            'clip', 'mocov3', 'mae', 'dinov2', 'dinov1', 'jepa' (in that order).

    Returns:
        Float tensor scaled to [0, 1] and normalized with the family's
        mean/std. For 'clip', 'dinov2' and 'jepa' it is also resized to 224
        with bicubic interpolation. CLIP resizes before normalizing; the others
        normalize first.

    Raises:
        ValueError: if ``enc_type`` matches no known family. Returning the
            unscaled 0-255 tensor would silently feed garbage to the encoder.
    """
    x = x / 255.
    if 'clip' in enc_type:
        x = torch.nn.functional.interpolate(x, 224, mode='bicubic')
        return Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)(x)
    if 'mocov3' in enc_type or 'mae' in enc_type:
        return Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
    if 'dinov2' in enc_type:
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        return torch.nn.functional.interpolate(x, 224, mode='bicubic')
    if 'dinov1' in enc_type:
        return Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
    if 'jepa' in enc_type:
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        return torch.nn.functional.interpolate(x, 224, mode='bicubic')
    raise ValueError('Unknown encoder type: {}'.format(enc_type))

N = 4      # number of random samples to inspect
SEED = 42  # fixes which samples are drawn
chosen_files = random.Random(SEED).sample(images_h5_cfg, N)
# Latent entries mirror image entries: "img" -> "img-mean-std-", ".png" -> ".npy".
# NOTE(review): str.replace rewrites *every* "img" in the key, so this assumes
# "img" only appears in the file-name prefix -- confirm against the archive layout.
chosen_vaes = [elem.replace("img", "img-mean-std-").replace(".png", ".npy") for elem in chosen_files]

# Close both archives once the samples are read, so re-running this cell does
# not accumulate open file handles.
with h5py.File("data/images.h5", "r") as image_h5, h5py.File("data/vae-sd.h5", "r") as vae_h5:
    ### Labels...
    # dataset.json holds "labels" as (file name, label) pairs, keyed with "/" separators.
    label_lookup = dict(load_h5_file(vae_h5, 'dataset.json')['labels'])
    labels = np.array([label_lookup[name.replace('\\', '/')] for name in chosen_vaes])
    labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])

    raw_images = torch.stack([torch.from_numpy(load_h5_file(image_h5, name)) for name in chosen_files])
    # The encoder type here must match the one passed to load_encoders.
    images = preprocess_raw_image(raw_images, "dinov2-vit-b").to("cuda:0")
    vaes = torch.stack([torch.from_numpy(load_h5_file(vae_h5, name)) for name in chosen_vaes]).to("cuda:0")
labels = torch.from_numpy(labels).to("cuda:0")

In [None]:
# Plot two images: 1. the raw (preprocessed) image, 2. the decoded ground-truth latent

@torch.no_grad()
def sample_posterior(moments, latents_scale=1., latents_bias=0.):
    """Draw one reparameterized sample from stored VAE posterior moments.

    Args:
        moments: Tensor whose dim 1 holds the mean and the std concatenated;
            it is split in half with ``torch.chunk`` (e.g. (B, 2*C, H, W)).
        latents_scale: Multiplier applied to the sample (scalar or broadcastable tensor).
        latents_bias: Offset added after scaling (scalar or broadcastable tensor).

    Returns:
        ``(mean + std * eps) * latents_scale + latents_bias`` with
        ``eps ~ N(0, I)``, shaped like the mean half of ``moments``.
    """
    mean, std = torch.chunk(moments, 2, dim=1)
    z = mean + std * torch.randn_like(mean)
    return z * latents_scale + latents_bias

im_idx = 1  # index (into the N sampled items) of the example to visualize

# Visualize the images
def denormalize(x, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD):
    """Invert a per-channel ``Normalize(mean, std)`` on a (3, H, W) tensor."""
    return x * torch.tensor(std).view(3, 1, 1) + torch.tensor(mean).view(3, 1, 1)

# Raw image: undo the ImageNet normalization applied in preprocess_raw_image.
# The bicubic resize there can overshoot [0, 1] slightly, so clamp for imshow.
original_image = denormalize(images[im_idx].cpu()).clamp(0.0, 1.0).permute(1, 2, 0).numpy()

# Ground-truth latent: sample only the selected item's posterior (default
# scale 1, bias 0) and decode it. Decoder output is mapped to [0, 1] via (x + 1) / 2.
with torch.no_grad():
    gt_latent = sample_posterior(vaes[im_idx:im_idx + 1])
    decoded_image = vae.decode(gt_latent).sample
decoded_image = ((decoded_image + 1.0) / 2.0).clamp(0.0, 1.0).squeeze(0).cpu()

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
panels = [
    (original_image, "Original Image"),
    (decoded_image.permute(1, 2, 0).numpy(), "Reconstructed Ground-Truth Latent"),
]
for axis, (picture, title) in zip(axes, panels):
    axis.imshow(picture)
    axis.set_title(title)
    axis.axis("off")
plt.tight_layout()
plt.show()

In [None]:
# Get the DINOv2 patch features; these are the targets compared against the
# model's projected features (zs_tilde) further down.
# Only inference is needed here, so skip autograd bookkeeping.
zs = []
with torch.no_grad():
    for encoder in encoders:
        zs.append(encoder.forward_features(images)['x_norm_patchtokens'])

print(zs[0].shape)

In [None]:
### Reproduce the diffusion (velocity-prediction) loss on one noisy sample per image

def interpolant(t):
    """Coefficients of the linear path between data (t=0) and noise (t=1).

    x_t = alpha_t * x + sigma_t * eps, with alpha_t = 1 - t and sigma_t = t.

    Args:
        t: Time value(s) in [0, 1]; a float or a tensor (applied elementwise).

    Returns:
        (alpha_t, sigma_t, d_alpha_t, d_sigma_t): the path coefficients and
        their time derivatives, which are the constants -1 and 1 for this path.
    """
    return 1 - t, t, -1, 1


# Backward-compatible alias for the original (misspelled) name.
inpterpolant = interpolant

# Per-channel scale / bias applied to the sampled VAE latents before diffusion.
latents_scale = torch.tensor(
    [0.18215, 0.18215, 0.18215, 0.18215]
    ).view(1, 4, 1, 1).to("cuda:0")
latents_bias = torch.tensor(
    [0., 0., 0., 0.]
    ).view(1, 4, 1, 1).to("cuda:0")

# Seed once so the posterior sample, timesteps and noise below are
# reproducible when the cell (or the whole notebook) is re-run.
torch.manual_seed(0)

x = sample_posterior(vaes, latents_scale=latents_scale, latents_bias=latents_bias)
print(f"vae: {vaes.shape}, x: {x.shape}")

model_kwargs = dict(y=labels.to("cuda:0"))
# One uniform timestep per sample, broadcastable over (C, H, W).
time_input = torch.rand((x.shape[0], 1, 1, 1), device="cuda:0", dtype=x.dtype)
print(time_input.shape)

noises = torch.randn_like(x)
# Linear interpolation
alpha_t, sigma_t, d_alpha_t, d_sigma_t = inpterpolant(time_input)

# Noisy latent: x_t = alpha_t * x + sigma_t * eps
model_input = alpha_t * x + sigma_t * noises
# Denoising (velocity) target: dx_t/dt = d_alpha_t * x + d_sigma_t * eps
model_target = d_alpha_t * x + d_sigma_t * noises

with torch.no_grad():
    model_output, zs_tilde = model(model_input, time_input.flatten(), **model_kwargs)

In [None]:
# Compute the denoising loss and the patch-wise cosine similarity between the encoder features and the model's projected features.

def mean_flat(x):
    """
    Average ``x`` over every dimension except the leading (batch) one.
    """
    reduce_dims = list(range(1, x.dim()))
    return x.mean(dim=reduce_dims)

denoising_loss = mean_flat((model_output - model_target) ** 2)
print(f"Loss: {denoising_loss.cpu().detach().numpy()}")


# Patch-wise cosine similarity between each encoder's features (zs) and the
# model's projected features (zs_tilde). It is averaged over patches per
# sample, then over encoders. This is a similarity (higher = better aligned),
# not a loss value.
proj_loss = []
for z, z_tilde in zip(zs, zs_tilde):
    z = torch.nn.functional.normalize(z, dim=-1)
    z_tilde = torch.nn.functional.normalize(z_tilde, dim=-1)
    # Per-patch cosine over the last (feature) axis, then mean over all
    # remaining non-batch axes -> one value per sample.
    # NOTE(review): features presumably (B, num_patches, D) -- confirm.
    proj_loss.append((z * z_tilde).sum(dim=-1).flatten(1).mean(dim=1))
proj_loss = torch.stack(proj_loss).mean(dim=0).cpu().numpy()  # mean over encoders -> (B,)
print(f"Patch-Patch cos-sim: {proj_loss}")


-----------------------------------------