In [24]:
import os

# Expose only GPU 0 to this process. This is set before any CUDA work happens
# in the notebook so that the restriction takes effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# The relative config/checkpoint paths used later in the notebook are resolved
# from the parent directory. Remember where the kernel started so that
# re-running this cell always lands in the same directory instead of climbing
# one level further on every execution.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, '..'))

### Stable Diffusion Model Init.

In [32]:
import torch
from omegaconf import OmegaConf

from txt2img_latent import get_parser, load_model_from_config, chunk

# 1) Build the argument parser provided by txt2img_latent.
parser = get_parser()

# 2) Define args_list: the desired arguments as a flat list of strings.
args_list = [
    "--config", "configs/stable-diffusion/v1-inference.yaml",
    "--ckpt", "models/ldm/stable-diffusion-v1/sd-v1-4.ckpt",
    "--H", "512",
    "--W", "512",
    "--C", "4",
    "--f", "8",
]

# 3) Parse the explicit list above rather than the process command line.
opt = parser.parse_args(args_list)

# Load the YAML config, then build the model from the checkpoint it describes.
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print('done')

Loading model from models/ldm/stable-diffusion-v1/sd-v1-4.ckpt
Global Step: 470000
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
done


### CLIP Model Init.

In [None]:
import torch
import clip
from PIL import Image

# Keep `device` a torch.device, the same type the Stable Diffusion cell
# assigns, so the variable does not change type depending on which cell ran
# last. clip.load accepts either a string or a torch.device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_model, preprocess = clip.load("ViT-B/32", device=device)
print('done')

### Decode and convert to PIL images

In [46]:
import numpy as np
import os
import torch
from einops import rearrange
import matplotlib.pyplot as plt
from PIL import Image

# 1) Return PIL images straight from the decoder so callers (plotting, CLIP
#    preprocessing) do not have to convert tensors themselves.
def decode_to_pil(image):
    """Decode a batch with the global `model` and return it as PIL images.

    Args:
        image: tensor passed to ``model.decode_first_stage`` after being moved
            to the global ``device``.
            NOTE(review): presumably a batch of first-stage latents - confirm
            against the code that wrote the .pt files.

    Returns:
        list of ``PIL.Image.Image``, one per element along the first (batch)
        dimension of the decoded tensor.
    """
    # Pure inference: disable autograd so no graph is built and the decoder's
    # intermediate activations are released immediately.
    with torch.no_grad():
        x = model.decode_first_stage(image.to(device))
        x = torch.clamp((x + 1.0) / 2.0, 0, 1)          # rescale [-1, 1] to [0, 1]
        x = (x * 255).byte().cpu()                      # uint8 (truncating cast)
    # BCHW  ->  list[HWC]
    return [Image.fromarray(t.permute(1, 2, 0).numpy()) for t in x]

def decode(image):
    """Decode `image` via decode_to_pil and display the results in one row.

    Every decoded picture gets its own subplot with the axis ticks removed.
    """
    pictures = decode_to_pil(image)
    count = len(pictures)
    fig = plt.figure(figsize=[18, 5])
    for position, picture in enumerate(pictures, start=1):
        ax = fig.add_subplot(1, count, position)
        ax.imshow(picture)
        # Hide tick marks on both axes; only the picture itself is of interest.
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

### Define the CLIP cosine-similarity function

In [50]:
import os
import torch
import torch.nn.functional as F
from tqdm import tqdm

def calc_cosim(pt_path, save_path, max_files=10):
    """Compute CLIP image/text cosine similarity for saved sample files.

    Each file in `pt_path` is loaded with ``torch.load`` and is expected to be
    a dict with the keys ``'text'`` and ``'image'``. ``data['image']`` is
    turned into PIL images by ``decode_to_pil``, encoded with the global
    ``clip_model``, and compared against the encoded ``data['text']``.
    The result is written to `save_path` under the same file name as
    ``{"cosim": tensor}``.

    Args:
        pt_path: directory containing the input files.
        save_path: output directory; created if it does not exist.
        max_files: maximum number of files to process, taken from the sorted
            file list. ``None`` processes every file. Defaults to 10, the
            limit this function previously hard-coded.
    """
    os.makedirs(save_path, exist_ok=True)

    # os.listdir returns entries in arbitrary order; sort so the processed
    # subset is reproducible, and skip anything that is not a regular file.
    files = sorted(
        os.path.join(pt_path, name)
        for name in os.listdir(pt_path)
        if os.path.isfile(os.path.join(pt_path, name))
    )
    if max_files is not None:
        files = files[:max_files]

    for file in tqdm(files):
        # NOTE(review): torch.load unpickles the file - only use on trusted data.
        # Load onto the CPU first so files saved from another GPU index still
        # open; tensors are moved to `device` explicitly below.
        data = torch.load(file, map_location="cpu")
        texts = clip.tokenize(data['text']).to(device)
        images = torch.stack(
            [preprocess(image) for image in decode_to_pil(data['image'])]
        ).to(device)

        with torch.no_grad():
            image_features = clip_model.encode_image(images)
            text_features = clip_model.encode_text(texts)
            # CLIP may run in half precision on the GPU; compute the similarity
            # in float32 so small differences between runs are not rounded away.
            cosim = F.cosine_similarity(image_features.float(), text_features.float())

        save_file = os.path.join(save_path, os.path.basename(file))
        # Store a CPU tensor so the result can be loaded on machines without a GPU.
        torch.save({"cosim": cosim.cpu()}, save_file)


In [None]:


# Score one sampler / step-count configuration at every guidance scale.
root_dir = '/data/archive/sd-v1-4'
model_name = 'dpm_solver++'
steps = 200

for scale in (1.5, 3.5, 5.5, 7.5, 9.5):
    # Input directory for this configuration; the scores are written to a
    # sibling directory that carries a "_clip" suffix.
    path = os.path.join(root_dir, f"{model_name}_steps{steps}_scale{scale}")
    calc_cosim(path, f"{path}_clip")

  0%|          | 0/10 [00:00<?, ?it/s]

 80%|████████  | 8/10 [00:02<00:00,  3.13it/s]

In [18]:
import os
root_dir = '/data/archive/sd-v1-4'
# Sanity check: for every (scale, sampler, steps) directory that exists, print
# the first three file paths and the keys stored in each of those files.
for scale in [1.5, 3.5, 5.5, 7.5, 9.5]:
    # dpm_solver_v3 is only listed for guidance scales 1.5 and 7.5.
    if scale in [1.5, 7.5]:
        models = ['dpm_solver++', 'uni_pc_bh2', 'dpm_solver_v3', 'rbf_order2', 'rbf_order3']
    else:
        models = ['dpm_solver++', 'uni_pc_bh2', 'rbf_order2', 'rbf_order3']
    # The loop variable must not be called `model`: that name holds the loaded
    # Stable Diffusion network, and rebinding it to a string would break
    # decode_to_pil for every cell executed afterwards.
    for sampler_name in models:
        for steps in [5, 6, 8, 10, 12, 15, 20]:
            path = os.path.join(root_dir, f"{sampler_name}_steps{steps}_scale{scale}")
            if os.path.exists(path):
                files = [os.path.join(path, f) for f in os.listdir(path)]
                files = sorted(files)
                print(files[:3])
                for file in files[:3]:
                    # Only the keys are inspected, so load onto the CPU.
                    # NOTE(review): torch.load unpickles - trusted files only.
                    data = torch.load(file, map_location="cpu")
                    print(data.keys())

    

['/data/archive/sd-v1-4/dpm_solver++_steps5_scale1.5/0.pt', '/data/archive/sd-v1-4/dpm_solver++_steps5_scale1.5/1.pt', '/data/archive/sd-v1-4/dpm_solver++_steps5_scale1.5/10.pt']
dict_keys(['latent', 'image', 'text'])
dict_keys(['latent', 'image', 'text'])
dict_keys(['latent', 'image', 'text'])
['/data/archive/sd-v1-4/dpm_solver++_steps6_scale1.5/0.pt', '/data/archive/sd-v1-4/dpm_solver++_steps6_scale1.5/1.pt', '/data/archive/sd-v1-4/dpm_solver++_steps6_scale1.5/10.pt']
dict_keys(['latent', 'image', 'text'])
dict_keys(['latent', 'image', 'text'])
dict_keys(['latent', 'image', 'text'])
['/data/archive/sd-v1-4/dpm_solver++_steps8_scale1.5/0.pt', '/data/archive/sd-v1-4/dpm_solver++_steps8_scale1.5/1.pt', '/data/archive/sd-v1-4/dpm_solver++_steps8_scale1.5/10.pt']
dict_keys(['latent', 'image', 'text'])
dict_keys(['latent', 'image', 'text'])
dict_keys(['latent', 'image', 'text'])
['/data/archive/sd-v1-4/dpm_solver++_steps10_scale1.5/0.pt', '/data/archive/sd-v1-4/dpm_solver++_steps10_scale1.