In [None]:
# Run from the parent directory (presumably the repo root) so the local `utils` / `models`
# packages and the relative `data/` and `exps/` paths used below resolve.
%cd ..
# Expose only physical GPU 1 to this process; inside the process it is addressed as "cuda:0".
# Must be set before torch initializes CUDA, i.e. before the imports cell runs.
%env CUDA_VISIBLE_DEVICES=1

In [None]:
import gc
import io
import json
import os
import random

import h5py
import matplotlib.pyplot as plt
import numpy as np
import PIL
import PIL.Image  # `import PIL` alone does not guarantee the PIL.Image submodule is loaded
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from tqdm import tqdm

from models.sit import SiT_models
from utils import load_encoders

# Optional PNG decoder. It is disabled here, so load_h5_file always decodes PNGs with PIL.
pyspng = None

In [None]:
### Prepare the data
def _read_json(json_path):
    """Open ``json_path`` as text and return the parsed JSON document."""
    with open(json_path, "r") as fp:
        return json.load(fp)


# Manifests of the two HDF5 archives. images_h5_cfg is sampled from below, so it is
# presumably the list of entry names in data/images.h5. vae_h5_cfg is currently unused.
images_h5_cfg = _read_json("data/images_h5.json")
vae_h5_cfg = _read_json("data/vae-sd_h5.json")

def load_h5_file(hf, path):
    """Read one entry of an HDF5 archive and decode it according to its file extension.

    Args:
        hf: open h5py.File (or any mapping) holding the entry; ``hf[path]`` must be
            convertible with ``np.array``.
        path: key of the entry; its extension selects the decoder.

    Returns:
        - ``.png``: decoded image array laid out as (C, H, W); a 2-D (grayscale)
          image comes back with C == 1.
        - ``.json``: the parsed JSON object (the stored bytes are decoded as UTF-8).
        - ``.npy``: the stored data as a numpy array.

    Raises:
        ValueError: if ``path`` has none of the extensions above.
    """
    if path.endswith('.png'):
        encoded = np.array(hf[path])
        if pyspng is not None:
            # pyspng.load takes the encoded PNG as raw ``bytes``, not a file-like object.
            rtn = pyspng.load(encoded.tobytes())
        else:
            rtn = np.array(PIL.Image.open(io.BytesIO(encoded)))
        # (H, W[, C]) -> (C, H, W); the reshape adds a channel axis to 2-D images.
        rtn = rtn.reshape(*rtn.shape[:2], -1).transpose(2, 0, 1)
    elif path.endswith('.json'):
        rtn = json.loads(np.array(hf[path]).tobytes().decode('utf-8'))
    elif path.endswith('.npy'):
        rtn = np.array(hf[path])
    else:
        raise ValueError('Unknown file type: {}'.format(path))
    return rtn

def preprocess_raw_image(x, enc_type):
    """Turn a batch of raw pixel images into encoder input.

    Args:
        x: image batch with pixel values in [0, 255]. It must be 4-D, assumed
            (B, C, H, W), because bicubic interpolation only accepts 4-D input.
        enc_type: encoder identifier. Currently unused: every encoder type gets the
            same ImageNet normalization and 224x224 resize.

    Returns:
        Float tensor scaled to [0, 1], normalized with the ImageNet mean/std and
        resized to 224x224 with bicubic interpolation.

    Note: this divides by 255 unconditionally, so it must be applied exactly once,
    to raw pixels.
    """
    x = x / 255.
    x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
    # One resize is enough; a second identical same-size bicubic resize is a no-op.
    x = torch.nn.functional.interpolate(x, 224, mode='bicubic')
    return x

# Number of evaluation samples and the batch size used for every forward pass.
N = 1024
BS = 8
# Fixed-seed sample of archive entries, so every run evaluates the same images.
chosen_files = random.Random(42).sample(images_h5_cfg, N)
# Matching VAE-moment entries: "img" -> "img-mean-std-" and ".png" -> ".npy" in each name.
chosen_vaes = [elem.replace("img", "img-mean-std-").replace(".png", ".npy") for elem in chosen_files]

# Everything needed is read into memory inside the `with`, so both archives are closed afterwards.
with h5py.File("data/images.h5", "r") as image_h5, h5py.File("data/vae-sd.h5", "r") as vae_h5:
    ### Labels...
    # dataset.json stores labels as (entry name, label) pairs; index them by entry name.
    label_by_name = dict(load_h5_file(vae_h5, 'dataset.json')['labels'])
    labels = np.array([label_by_name[vae_name.replace('\\', '/')] for vae_name in chosen_vaes])
    labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])

    # Keep the images as raw pixel tensors. The feature-extraction loop calls
    # preprocess_raw_image on every batch, so normalizing here as well would apply
    # the /255 + ImageNet normalization twice.
    images = torch.stack([torch.from_numpy(load_h5_file(image_h5, name)) for name in chosen_files])
    vaes = torch.stack([torch.from_numpy(load_h5_file(vae_h5, name)) for name in chosen_vaes])

labels = torch.from_numpy(labels)
assert len(images) == len(vaes) == len(labels) == N

In [None]:
### Load other models
# Alignment-target encoder(s); only the first entry is used in this notebook.
encoders, encoder_types, architectures = load_encoders("dinov2-vit-b", "cuda:0")
encoder = encoders[0]
encoder_type = encoder_types[0]
architecture = architectures[0]

### Model HPs
model_name = "SiT-B/2"
resolution = 256
num_classes = 1000
assert resolution % 8 == 0
latent_size = resolution // 8  # latent grid size; assumes an 8x spatially-downsampling VAE
# Feature width of each encoder, passed to SiT as z_dims.
z_dims = [enc.embed_dim for enc in encoders]
encoder_depth = 8
block_kwargs = {"fused_attn": False, "qk_norm": False}


def get_model(ckpt_path, device="cuda:0"):
    """Build a SiT model from the module-level hyperparameters and load a checkpoint into it.

    Args:
        ckpt_path: path to a checkpoint file containing a dict with a ``"model"`` state dict.
        device: device the model is placed on (default ``"cuda:0"``, as before).

    Returns:
        The model, in eval mode, on ``device``.
    """
    model = SiT_models[model_name](
        latent_size=latent_size,
        num_classes=num_classes,
        use_cfg=True,
        z_dims=z_dims,
        encoder_depth=encoder_depth,
        **block_kwargs,
    ).to(device)
    # Deserialize on the CPU. load_state_dict copies the weights onto `device` anyway,
    # and any other entries the checkpoint may hold never need to occupy GPU memory.
    state_dict = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(state_dict["model"])
    del state_dict
    model.eval()
    return model

def interpolant(t):
    """Coefficients of the linear path x_t = alpha_t * x + sigma_t * noise.

    Uses alpha_t = 1 - t and sigma_t = t. Returns (alpha_t, sigma_t, d_alpha_t, d_sigma_t),
    where the last two are the constant time derivatives -1 and 1.
    """
    d_alpha_t, d_sigma_t = -1, 1
    return 1 - t, t, d_alpha_t, d_sigma_t

def mean_flat(x):
    """Average ``x`` over every dimension except the first one."""
    reduce_dims = list(range(1, x.dim()))
    return x.mean(dim=reduce_dims)

@torch.no_grad()
def sample_posterior(moments, latents_scale=1., latents_bias=0.):
    """Sample from a diagonal Gaussian given as stacked moments, then apply an affine map.

    ``moments`` is split in two along dim 1: the first half is the mean, the second half
    the standard deviation. Returns ``(mean + std * eps) * latents_scale + latents_bias``
    with ``eps`` drawn from a standard normal of the mean's shape.
    """
    mean, std = moments.chunk(2, dim=1)
    eps = torch.randn_like(mean)
    sample = mean + std * eps
    return sample * latents_scale + latents_bias

# Affine map applied to the sampled VAE latents, shaped (1, 4, 1, 1) to broadcast over batch
# and spatial dims. Each of the 4 channels is scaled by 0.18215 (presumably the SD-VAE
# scaling factor) with zero bias.
latents_scale = torch.full((1, 4, 1, 1), 0.18215).to("cuda:0")
latents_bias = torch.zeros((1, 4, 1, 1)).to("cuda:0")

In [None]:
ckpt_paths = ["exps/sit-b-base-400k/checkpoints", "exps/sit-b-linear-dinov2-b-enc8-400k/checkpoints"]

# Alignment targets: encoder patch tokens for every sampled image, computed once and kept on the CPU.
# preprocess_raw_image divides by 255, so `images` must still hold raw pixel values here.
zs_batches = []
with torch.no_grad():
    for i in tqdm(range(0, N, BS)):
        images_batch = preprocess_raw_image(images[i:i+BS].to("cuda:0"), encoder_type)
        patch_tokens = encoder.forward_features(images_batch)['x_norm_patchtokens']
        # Offload each batch immediately so the N feature maps never accumulate on the GPU.
        zs_batches.append(patch_tokens.cpu())
zs = torch.cat(zs_batches, dim=0)
assert zs.shape[0] == N


In [None]:
# Every checkpoint is evaluated with the same seed, so all of them see identical
# posterior samples, timesteps and noise and the curves are directly comparable.
EVAL_SEED = 0


def _extract_zs_tilde(model, seed=EVAL_SEED):
    """Run `model` on noised latents of all N samples and return its projected features (CPU)."""
    torch.manual_seed(seed)
    zs_tilde_batches = []
    with torch.no_grad():
        for i in tqdm(range(0, N, BS)):
            vaes_batch = vaes[i:i+BS].to("cuda:0")
            labels_batch = labels[i:i+BS].to("cuda:0")
            x = sample_posterior(vaes_batch, latents_scale=latents_scale, latents_bias=latents_bias)
            time_input = torch.rand((x.shape[0], 1, 1, 1), device="cuda:0", dtype=x.dtype)
            noises = torch.randn_like(x)
            alpha_t, sigma_t, _, _ = interpolant(time_input)  # linear
            model_input = alpha_t * x + sigma_t * noises
            _, z_tilde = model(model_input, time_input.flatten(), y=labels_batch)
            # Keep the projection for the first encoder (the one `zs` was computed with).
            zs_tilde_batches.append(z_tilde[0].cpu())
    return torch.cat(zs_tilde_batches, dim=0)


def _neg_cos_sim(zs, zs_tilde):
    """Negative cosine similarity of matching patch tokens, averaged over patches and samples."""
    cos = (F.normalize(zs, dim=-1) * F.normalize(zs_tilde, dim=-1)).sum(dim=-1)
    return -cos.mean().item()


def _kernel_align_l2(a_mat, zs_tilde):
    """RMS difference between the target and model row-normalized patch Gram matrices."""
    a_tilde_mat = F.normalize(zs_tilde @ zs_tilde.transpose(1, 2), dim=-1)
    return torch.sqrt(F.mse_loss(a_mat, a_tilde_mat, reduction='mean')).item()


# The target Gram matrices depend only on `zs`, so compute them once for all checkpoints.
a_mat = F.normalize(zs @ zs.transpose(1, 2), dim=-1)

ckpt_stats = {}
for ckpt_path in ckpt_paths:
    stats_neg_cos_sim = {}
    stats_frobenius_l2_norm = {}

    # Checkpoint files are named "<step>.pt"; skip anything else in the directory.
    for elem in sorted(name for name in os.listdir(ckpt_path) if name.endswith('.pt')):
        step = int(os.path.splitext(elem)[0])
        model = get_model(os.path.join(ckpt_path, elem))
        zs_tilde = _extract_zs_tilde(model)

        # Release this model before the next checkpoint is loaded.
        del model
        gc.collect()
        torch.cuda.empty_cache()

        # Part 1: cos-sim
        stats_neg_cos_sim[step] = _neg_cos_sim(zs, zs_tilde)
        # Part 2: L2 matrix loss
        stats_frobenius_l2_norm[step] = _kernel_align_l2(a_mat, zs_tilde)

    ckpt_stats[ckpt_path] = {"neg_cos_sim": stats_neg_cos_sim, "kernel_align_patch_l2": stats_frobenius_l2_norm}

In [None]:
# Metrics per checkpoint directory: {ckpt_path: {"neg_cos_sim": {step: value}, "kernel_align_patch_l2": {step: value}}}
ckpt_stats

In [None]:
def plot_metric(stats, metric_key, ylabel):
    """Plot one metric against training step, with one line per checkpoint directory.

    Args:
        stats: {ckpt_path: {metric_key: {step: value}}}, as built by the evaluation loop.
        metric_key: which metric to plot, e.g. "neg_cos_sim".
        ylabel: y-axis label; also used to build the title "<ylabel> v.s. Steps".
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    for ckpt_dir, metrics in stats.items():
        # Sort by step so each line is drawn in training order regardless of dict order.
        points = sorted(metrics[metric_key].items())
        steps = [step for step, _ in points]
        values = [value for _, value in points]
        ax.plot(steps, values, label=ckpt_dir, marker='o')
    ax.legend()
    ax.set_xlabel("Steps")
    ax.set_ylabel(ylabel)
    ax.set_title(f"{ylabel} v.s. Steps")
    plt.show()


plot_metric(ckpt_stats, "neg_cos_sim", "Neg Cosine Similarity")
plot_metric(ckpt_stats, "kernel_align_patch_l2", "Kernel Align Patch L2")