In [1]:
# Change the working directory to the parent folder; presumably this makes the
# relative paths used below (models/, pretrained_models/, data/) resolve - confirm.
# NOTE(review): not idempotent - re-running this cell moves up one more level.
%cd ..
# Set CUDA_VISIBLE_DEVICES for this kernel process.
# NOTE(review): with a single visible device, the "cuda:0" used later in this
# notebook presumably refers to physical GPU 3 - confirm.
%env CUDA_VISIBLE_DEVICES=3

/mnt/SSD/lengx/REPA
env: CUDA_VISIBLE_DEVICES=3


-----------------------------------

In [2]:
import torch
from models.sit import SiT_models
from models.original_sit import SiT_models as original_SiT_models

In [3]:
# Deserialize both checkpoints onto the CPU so no GPU memory is needed for the
# key/shape comparison below.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
sit_repa = torch.load("pretrained_models/last.pt", map_location="cpu")
sit = torch.load("pretrained_models/SiT-XL-2-256.pt", map_location="cpu")

In [6]:
# Attention options forwarded unchanged to the SiT constructor.
block_kwargs = dict(fused_attn=False, qk_norm=False)

# Build the REPA variant of SiT-XL/2.
sit_xl = SiT_models['SiT-XL/2'](
    latent_size=32, num_classes=1000, use_cfg=True,
    z_dims=[768], encoder_depth=8,
    **block_kwargs,
)

# Keep this as the last expression: the notebook displays the key-matching
# report returned by load_state_dict as the cell output.
sit_xl.load_state_dict(sit_repa)

<All keys matched successfully>

In [8]:
# Build the original SiT-XL/2 and load the original checkpoint into it.
original_sit_xl = original_SiT_models['SiT-XL/2'](input_size=32, num_classes=1000)

# Last expression on purpose: the returned key-matching report is the cell output.
original_sit_xl.load_state_dict(sit)

<All keys matched successfully>

In [11]:
# Compare the parameter names of the two checkpoints.
sit_keys = set(sit.keys())
repa_keys = set(sit_repa.keys())

print(sit_keys - repa_keys)   # only in the original SiT checkpoint
print(repa_keys & sit_keys)   # present in both checkpoints
print(repa_keys - sit_keys)   # only in the REPA checkpoint

set()
{'blocks.4.attn.proj.weight', 'blocks.13.attn.qkv.weight', 'blocks.20.mlp.fc1.bias', 'blocks.23.mlp.fc2.weight', 'blocks.24.attn.qkv.bias', 'blocks.0.mlp.fc2.weight', 'blocks.26.mlp.fc2.bias', 'y_embedder.embedding_table.weight', 'blocks.27.attn.proj.bias', 'blocks.23.mlp.fc1.weight', 'blocks.6.attn.proj.bias', 'blocks.0.adaLN_modulation.1.weight', 'blocks.20.adaLN_modulation.1.weight', 'blocks.24.adaLN_modulation.1.weight', 'blocks.6.mlp.fc1.weight', 'blocks.24.mlp.fc2.bias', 'blocks.16.mlp.fc1.weight', 'blocks.5.mlp.fc1.weight', 'blocks.17.adaLN_modulation.1.weight', 'blocks.2.attn.qkv.weight', 'blocks.6.mlp.fc2.bias', 'blocks.14.mlp.fc1.weight', 'blocks.8.mlp.fc1.weight', 'blocks.22.mlp.fc2.bias', 'blocks.24.attn.proj.bias', 'blocks.22.adaLN_modulation.1.weight', 'blocks.23.attn.qkv.weight', 'blocks.9.mlp.fc1.bias', 'blocks.27.adaLN_modulation.1.weight', 'blocks.19.mlp.fc1.bias', 'blocks.0.mlp.fc1.weight', 'final_layer.linear.weight', 'blocks.18.attn.proj.bias', 'blocks.0.attn

In [12]:
# Report every parameter that exists in both checkpoints but has a different
# tensor shape. Keys are sorted so the report order is stable between runs
# (plain set iteration order is not).
for elem in sorted(set(sit_repa.keys()) & set(sit.keys())):
    if sit[elem].shape != sit_repa[elem].shape:
        print(f"SiT: {elem} - {sit[elem].shape} === SiT-REPA: {elem} - {sit_repa[elem].shape}")

In [10]:
# Parameters that only the REPA checkpoint has get a placeholder in the SiT
# state dict: a uniform-random tensor with the same shape and dtype as the
# REPA tensor.
# NOTE(review): unseeded, so the placeholder values differ from run to run.
keys_missing_in_sit = set(sit_repa.keys()) - set(sit.keys())
for name in keys_missing_in_sit:
    reference = sit_repa[name]
    sit[name] = torch.rand_like(reference)


In [7]:
# Trim the final projection of the original SiT checkpoint to its first half
# along dim 0 so it matches the REPA checkpoint.
# NOTE(review): presumably the dropped half is the sigma-prediction output of
# the original SiT - confirm.
#
# The REPA tensor's row count is used as a guard so the cell is idempotent:
# re-running it no longer halves an already-trimmed tensor a second time.
lst = [k for k in sit if "final_layer.linear" in k]

for k in lst:
    if sit[k].shape[0] == sit_repa[k].shape[0]:
        continue  # already trimmed on a previous run
    print(k, sit[k].shape)
    sit[k] = sit[k].chunk(2, dim=0)[0]
    print(sit[k].shape)

final_layer.linear.weight torch.Size([32, 1152])
torch.Size([16, 1152])
final_layer.linear.bias torch.Size([32])
torch.Size([16])


In [13]:
# Write the `sit` state dict, as modified by the cells above, to a new checkpoint file.
torch.save(sit, "pretrained_models/SiT-XL-2-256-fixed.pt")

--------------------------------------------------

In [5]:
### From the above, we know that the original SiT implementation differs from the REPA version (sigma prediction, etc.).
### Let's investigate the original SiT features further, and their alignment with the DINOv2 features.

### Imports
import torch
import random
import timm
from models.sit import SiT_models
from diffusers import AutoencoderKL
import torch.nn.functional as F
import json
import numpy as np
import h5py
import os
import PIL
import json
import gc
import io
from torchvision.transforms import Normalize
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import matplotlib.pyplot as plt
from tqdm import tqdm
from metrics import AlignmentMetrics

# pyspng is not used in this notebook: with the name set to None, load_h5_file
# takes its PIL branch when decoding PNGs.
pyspng = None

### Helper funcs
def load_h5_file(hf, path):
    """Read one entry from an opened HDF5-style container, decoded by file extension.

    Args:
        hf: Mapping-like object indexed by ``path`` (e.g. an open ``h5py.File``).
        path: Entry name; its suffix selects the decoder.

    Returns:
        ``.png``  -> decoded image as an array with the channel axis moved first;
                     a trailing channel axis is added when the decoded image
                     has only two dimensions.
        ``.json`` -> the parsed JSON object.
        ``.npy``  -> the stored data as a NumPy array.

    Raises:
        ValueError: If ``path`` has none of the supported extensions.
    """
    if path.endswith('.png'):
        encoded = io.BytesIO(np.array(hf[path]))
        if pyspng is None:
            pixels = np.array(PIL.Image.open(encoded))
        else:
            pixels = pyspng.load(encoded)
        # Collapse everything after the first two axes into one channel axis,
        # then move that axis to the front.
        height, width = pixels.shape[:2]
        return pixels.reshape(height, width, -1).transpose(2, 0, 1)

    if path.endswith('.json'):
        text = np.array(hf[path]).tobytes().decode('utf-8')
        return json.loads(text)

    if path.endswith('.npy'):
        return np.array(hf[path])

    raise ValueError('Unknown file type: {}'.format(path))

def preprocess_raw_image(x, enc_type):
    """Turn raw images into the normalized 224x224 input fed to the encoder.

    Args:
        x: Image batch with values on a 0-255 scale.
            NOTE(review): assumed to be (N, C, H, W), as required by
            ``interpolate`` with a scalar size - confirm.
        enc_type: Encoder identifier. Unused here; kept so existing callers
            keep working.

    Returns:
        The batch scaled to [0, 1], normalized with the ImageNet default
        mean/std and bicubically resized to 224x224.
    """
    x = x / 255.
    x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
    # Resize once. The original repeated this call; resizing a 224x224 input
    # to 224x224 samples exactly at the pixel centers, so the second pass was
    # redundant work.
    x = torch.nn.functional.interpolate(x, 224, mode='bicubic')
    return x

def interpolant(t):
    """Linear interpolant coefficients at time ``t``.

    Returns:
        ``(alpha_t, sigma_t, d_alpha_t, d_sigma_t)`` where ``alpha_t = 1 - t``,
        ``sigma_t = t`` and the two derivatives are the constants ``-1`` and ``1``.
    """
    return 1 - t, t, -1, 1

def mean_flat(x):
    """Average ``x`` over every dimension except the first one."""
    trailing_dims = list(range(1, x.dim()))
    return x.mean(dim=trailing_dims)

@torch.no_grad()
def sample_posterior(moments, latents_scale=1., latents_bias=0.):
    """Draw one Gaussian sample from stacked (mean, std) moments, then rescale it.

    Args:
        moments: Tensor whose dim 1 holds the mean in its first half and the
            standard deviation in its second half.
        latents_scale: Multiplier applied to the sample (scalar or broadcastable tensor).
        latents_bias: Offset added after scaling (scalar or broadcastable tensor).

    Returns:
        ``(mean + std * eps) * latents_scale + latents_bias`` with ``eps ~ N(0, I)``.
    """
    mean, std = moments.chunk(2, dim=1)
    noise = torch.randn_like(mean)
    sample = mean + std * noise
    return sample * latents_scale + latents_bias

# Per-channel affine constants applied to the sampled latents, shaped
# (1, 4, 1, 1) so they broadcast over a 4-channel batch and placed on the GPU.
latents_scale = torch.tensor([0.18215] * 4).view(1, 4, 1, 1).to("cuda:0")
latents_bias = torch.tensor([0.] * 4).view(1, 4, 1, 1).to("cuda:0")

### Sample data
# Entry listings of the two HDF5 archives.
with open("data/images_h5.json", "r") as f:
    images_h5_cfg = json.load(f)
with open("data/vae-sd_h5.json", "r") as f:
    vae_h5_cfg = json.load(f)

N = 256
BS = 8
# A private RNG with a fixed seed keeps the chosen subset identical between runs.
chosen_files = random.Random(42).sample(images_h5_cfg, N)
# Derive each image's matching VAE-moments entry name from its file name.
chosen_vaes = [
    img_path.replace("img", "img-mean-std-").replace(".png", ".npy")
    for img_path in chosen_files
]

# NOTE(review): both archives stay open after this cell; close them once they are no longer needed.
image_h5 = h5py.File("data/images.h5", "r")
vae_h5 = h5py.File("data/vae-sd.h5", "r")

### Labels...
fname = 'dataset.json'
label_by_path = dict(load_h5_file(vae_h5, fname)['labels'])
label_array = np.array(
    [label_by_path[vae_path.replace('\\', '/')] for vae_path in chosen_vaes]
)
# 1-D labels are cast to int64 and 2-D labels to float32.
label_array = label_array.astype({1: np.int64, 2: np.float32}[label_array.ndim])
labels = torch.from_numpy(label_array)

images = preprocess_raw_image(
    torch.stack(
        [torch.from_numpy(load_h5_file(image_h5, img_path)) for img_path in chosen_files]
    ),
    "dinov2-vit-b",
)
vaes = torch.stack(
    [torch.from_numpy(load_h5_file(vae_h5, vae_path)) for vae_path in chosen_vaes]
)


### Prepare the external encoder
# Fetch the DINOv2 ViT-g/14 backbone through torch.hub.
encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')

# Drop the head and put an identity module in its place.
del encoder.head
encoder.head = torch.nn.Identity()

# Resample the absolute position embedding to a 16x16 grid
# (224-pixel inputs with 14-pixel patches give 16 patches per side).
resized_pos_embed = timm.layers.pos_embed.resample_abs_pos_embed(
    encoder.pos_embed.data,
    [16, 16],
)
encoder.pos_embed.data = resized_pos_embed

encoder = encoder.to("cuda:0")
# Last expression on purpose: eval() returns the module, whose repr is the cell output.
encoder.eval()

Using cache found in /home/lengx/.cache/torch/hub/facebookresearch_dinov2_main


DinoVisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1536, kernel_size=(14, 14), stride=(14, 14))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-39): 40 x NestedTensorBlock(
      (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)
      (attn): MemEffAttention(
        (qkv): Linear(in_features=1536, out_features=4608, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1536, out_features=1536, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)
      (mlp): SwiGLUFFNFused(
        (w12): Linear(in_features=1536, out_features=8192, bias=True)
        (w3): Linear(in_features=4096, out_features=1536, bias=True)
      )
      (ls2): LayerScale()
      (drop_path2): Identity()
    )
  )
  (norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)
  (head