In [1]:
# Shell escape: ask nvidia-smi for each GPU's name, total memory and free memory,
# printed as CSV rows without a header line.
!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader

NVIDIA A100-PCIE-40GB, 40960 MiB, 19652 MiB
NVIDIA A100-PCIE-40GB, 40960 MiB, 28478 MiB
NVIDIA A100-PCIE-40GB, 40960 MiB, 19578 MiB


In [16]:
# compute CLIP-space cosine-similarity distance
import torch
import os
import clip
import pathlib
import tqdm
import numpy as np
import warnings
import json
import sklearn.preprocessing
from diffusers import StableDiffusionPipeline
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from packaging import version
from shutil import rmtree

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Seed for the image-generation RNG; None leaves generation unseeded.
SEED = None
# Preferred CUDA device index. If the host exposes fewer GPUs than this index
# requires, fall back to cuda:0 instead of failing later with an invalid
# device ordinal; without CUDA, run on the CPU.
GPU_INDEX = 1
if torch.cuda.is_available():
    if GPU_INDEX < torch.cuda.device_count():
        DEVICE = torch.device(f"cuda:{GPU_INDEX}")
    else:
        DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

pretrained_model_name_or_path = "/root/autodl-tmp/stable_diffusion/stable-diffusion-v1-5"
learned_embeds_path = "/root/autodl-tmp/textual_inversion/merged_embeddings/custom_chair&custom_cat/learned_embeds_factor=0.12&learned_embeds_factor=0.5.bin"
original_image_dir_1 = "/root/autodl-tmp/textual_inversion/data/chair"
original_image_dir_2 = "/root/autodl-tmp/textual_inversion/data/cat"

# Derive names from the embedding path:
#   all_embedding_path -> directory that holds the embedding file
#   embeds_name        -> embedding file name without its extension
#   dataset_name_list  -> the directory name split on "&"
all_embedding_path, embeds_suffix = os.path.split(learned_embeds_path)
# Index instead of unpacking into `_`, which IPython uses for the last output.
embeds_name = os.path.splitext(embeds_suffix)[0]
all_dataset_name = os.path.basename(all_embedding_path)
dataset_name_list = all_dataset_name.split("&")

# Tokenizer and text encoder come from sub-folders of the pretrained checkpoint;
# the text encoder is loaded in half precision.
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="tokenizer",
)
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch.float16
)

In [17]:
# Register every learned placeholder token with the tokenizer and copy its
# embedding vector into the text encoder's input-embedding matrix.
# NOTE(security): torch.load unpickles the file, so only load embedding files
# that come from a trusted source.
loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
trained_token_list = list(loaded_learned_embeds.keys())
dtype = text_encoder.get_input_embeddings().weight.dtype
for i, trained_token in enumerate(trained_token_list):
    # Add the token first and stop before touching the encoder if it already
    # exists, so an existing embedding row is never overwritten.
    num_added_tokens = tokenizer.add_tokens(trained_token)
    if num_added_tokens == 0:
        raise ValueError(f"The tokenizer already contains the token {trained_token}. "
                         "Please pass a different `token` that is not already in the tokenizer.")
    # Tensor.to returns a new tensor, so keep the result of the cast to the
    # text encoder's embedding dtype.
    embeds = loaded_learned_embeds[trained_token].to(dtype)
    # Grow the embedding matrix so the new token id has a row.
    text_encoder.resize_token_embeddings(len(tokenizer))
    # Look up the id assigned to the token and store its embedding there.
    token_id = tokenizer.convert_tokens_to_ids(trained_token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds
    print(f"placeholder token for dataset {dataset_name_list[i]}: {trained_token}")

# Build the generation pipeline around the extended tokenizer / text encoder.
pipe = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path,
                                               torch_dtype=torch.float16,
                                               text_encoder=text_encoder,
                                               tokenizer=tokenizer).to(DEVICE)

placeholder token for dataset custom_chair: <custom_chair>
placeholder token for dataset custom_cat: <custom_cat>


In [18]:
# Prompt that mentions the first two learned placeholder tokens
# (alternative wording: "A photo depicts <*>").
prompt = "A photo of {} and {}".format(*trained_token_list[:2])

# Only build a seeded generator when a seed was configured.
if SEED is None:
    generator = None
else:
    generator = torch.Generator(device=DEVICE).manual_seed(SEED)

N = 16  # number of random generated images

# Generated images are written next to the embedding file.
clip_image_dir = os.path.join(all_embedding_path, "clip_images")
os.makedirs(clip_image_dir, exist_ok=True)

for image_number in range(1, N + 1):
    result = pipe(
        prompt,
        num_inference_steps=50,
        guidance_scale=7.5,
        generator=generator,
    )
    generated_image = result.images[0]
    # The carriage return makes each progress message overwrite the previous one.
    print(f"Generating images: {image_number}/{N}", end="\r")
    target_path = os.path.join(clip_image_dir, "{}_{}.png".format(prompt, image_number))
    generated_image.save(target_path)

100%|██████████| 50/50 [00:04<00:00, 12.41it/s]


Generating images: 1/16

100%|██████████| 50/50 [00:03<00:00, 12.57it/s]


Generating images: 2/16

100%|██████████| 50/50 [00:03<00:00, 12.64it/s]


Generating images: 3/16

100%|██████████| 50/50 [00:03<00:00, 12.73it/s]


Generating images: 4/16

100%|██████████| 50/50 [00:03<00:00, 12.65it/s]


Generating images: 5/16

100%|██████████| 50/50 [00:03<00:00, 12.73it/s]


Generating images: 6/16

100%|██████████| 50/50 [00:03<00:00, 12.51it/s]


Generating images: 7/16

100%|██████████| 50/50 [00:03<00:00, 12.50it/s]


Generating images: 8/16

100%|██████████| 50/50 [00:03<00:00, 12.60it/s]


Generating images: 9/16

100%|██████████| 50/50 [00:03<00:00, 12.76it/s]


Generating images: 10/16

100%|██████████| 50/50 [00:03<00:00, 12.78it/s]


Generating images: 11/16

100%|██████████| 50/50 [00:04<00:00, 12.28it/s]


Generating images: 12/16

100%|██████████| 50/50 [00:03<00:00, 12.56it/s]


Generating images: 13/16

100%|██████████| 50/50 [00:03<00:00, 12.65it/s]


Generating images: 14/16

100%|██████████| 50/50 [00:03<00:00, 12.63it/s]


Generating images: 15/16

100%|██████████| 50/50 [00:03<00:00, 12.61it/s]


Generating images: 16/16

In [19]:
# File extensions (compared case-insensitively) that count as images.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tiff')


def list_image_paths(image_dir, extensions=IMAGE_EXTENSIONS):
    """Return the sorted paths of the image files directly inside ``image_dir``.

    Args:
        image_dir: directory to scan (not recursive).
        extensions: tuple of lower-case file suffixes to accept.

    Returns:
        A list of ``os.path.join(image_dir, name)`` for every entry whose
        lower-cased name ends with one of ``extensions``, in sorted name order.

    Raises:
        ValueError: if no entry in ``image_dir`` matches.
    """
    image_paths = [
        os.path.join(image_dir, file_name)
        # Sort because os.listdir returns entries in arbitrary order.
        for file_name in sorted(os.listdir(image_dir))
        # Lower-case the name so that e.g. ".JPG" files are not silently skipped.
        if file_name.lower().endswith(extensions)
    ]
    if not image_paths:
        raise ValueError(f"No image files {extensions} found in {image_dir}")
    return image_paths


clip_images_path_list = list_image_paths(clip_image_dir)
original_images_path_list_1 = list_image_paths(original_image_dir_1)
num_original_images_1 = len(original_images_path_list_1)
original_images_path_list_2 = list_image_paths(original_image_dir_2)
num_original_images_2 = len(original_images_path_list_2)

In [20]:
def _convert_image_to_rgb(image):
    """Return ``image`` converted to RGB mode.

    Defined at module level (instead of as a lambda inside ``Compose``) so the
    transform pipeline can be pickled for DataLoader worker processes.
    """
    return image.convert("RGB")


class CLIPImageDataset(torch.utils.data.Dataset):
    """Dataset that loads image files and applies CLIP-style preprocessing.

    Args:
        data: indexable collection of image file paths.

    Each item is a dict ``{'image': tensor}`` holding the preprocessed image.
    """

    def __init__(self, data):
        self.data = data
        # only 224x224 ViT-B/32 supported for now
        self.preprocess = self._transform_test(224)

    def _transform_test(self, n_px):
        """Build the transform: resize, center-crop to ``n_px``, RGB, tensor, normalize."""
        # Prefer torchvision's InterpolationMode enum; passing the PIL integer
        # constant makes torchvision emit a deprecation warning. Fall back to
        # the PIL constant on torchvision versions without the enum.
        try:
            from torchvision.transforms import InterpolationMode
            bicubic = InterpolationMode.BICUBIC
        except ImportError:
            bicubic = Image.BICUBIC
        return Compose([
            Resize(n_px, interpolation=bicubic),
            CenterCrop(n_px),
            _convert_image_to_rgb,
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

    def __getitem__(self, idx):
        """Load the image at ``self.data[idx]`` and return ``{'image': preprocessed}``."""
        image_path = self.data[idx]
        # Run preprocessing while the file is open and close it afterwards, so
        # iterating the dataset does not leak one file handle per image.
        with Image.open(image_path) as raw_image:
            image = self.preprocess(raw_image)
        return {'image': image}

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.data)

In [21]:
def extract_all_images(images, model, device, batch_size=64, num_workers=8):
    """Encode a collection of images with ``model.encode_image``.

    Args:
        images: collection handed to ``CLIPImageDataset``.
        model: object exposing ``encode_image(batch)``.
        device: device each batch is moved to before encoding.
        batch_size: number of images per batch.
        num_workers: number of DataLoader worker processes.

    Returns:
        NumPy array made by vertically stacking the features of every batch,
        in dataset order (the loader does not shuffle).
    """
    loader = torch.utils.data.DataLoader(
        CLIPImageDataset(images),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )
    feature_chunks = []
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        for batch in tqdm.tqdm(loader):
            # Move the batch to the target device, then cast to half precision.
            pixels = batch['image'].to(device).to(torch.float16)
            encoded = model.encode_image(pixels)
            feature_chunks.append(encoded.cpu().numpy())
    return np.vstack(feature_chunks)

def get_clip_score(model, clip_images, original_images, device, w=1.0):
    """Mean pairwise cosine similarity between two sets of image features.

    Args:
        model: encoder passed to ``extract_all_images``; only used when an
            argument is given as a list.
        clip_images: feature array with one row per image, or a list, in
            which case features are extracted with ``extract_all_images``.
        original_images: same as ``clip_images`` for the reference set.
        device: device passed to ``extract_all_images`` when extracting.
        w: scalar weight multiplied into every similarity.

    Returns:
        The mean over all (clip image, original image) pairs of
        ``w * max(cosine_similarity, 0)``.
    """
    if isinstance(clip_images, list):
        # need to extract image features
        clip_images = extract_all_images(clip_images, model, device)
    if isinstance(original_images, list):
        # need to extract image features
        original_images = extract_all_images(original_images, model, device)

    # as of numpy 1.21, normalize doesn't work properly for float16
    if version.parse(np.__version__) < version.parse('1.21'):
        clip_images = sklearn.preprocessing.normalize(clip_images, axis=1)
        original_images = sklearn.preprocessing.normalize(original_images, axis=1)
    else:
        # Adjacent string literals are concatenated as-is, so each fragment
        # ends with a space to keep the words of the message separated.
        warnings.warn(
            'due to a numerical instability, new numpy normalization is slightly different than '
            'paper results. To exactly replicate paper results, please use numpy version less '
            'than 1.21, e.g., 1.20.3.')
        # Scale every row to unit L2 norm by hand.
        clip_images = clip_images / np.sqrt(np.sum(clip_images ** 2, axis=1, keepdims=True))
        original_images = original_images / np.sqrt(np.sum(original_images ** 2, axis=1,
                                                           keepdims=True))

    # Rows are unit vectors, so the dot product is the cosine similarity of
    # every pair; negative similarities are clipped to zero before averaging.
    per = w * np.clip(np.dot(clip_images, original_images.T), 0, None)
    return np.mean(per)

In [22]:
# Load the ViT-B/32 CLIP model (non-JIT) on the configured device and switch it
# to evaluation mode.
clip_model, clip_transform = clip.load("ViT-B/32", device=DEVICE, jit=False)
clip_model.eval()

# Encode each image set as a single batch: generated images first, then the
# two sets of original images.
clip_features = extract_all_images(
    clip_images_path_list, clip_model, DEVICE,
    batch_size=N, num_workers=8,
)
original_features_1 = extract_all_images(
    original_images_path_list_1, clip_model, DEVICE,
    batch_size=num_original_images_1, num_workers=8,
)
original_features_2 = extract_all_images(
    original_images_path_list_2, clip_model, DEVICE,
    batch_size=num_original_images_2, num_workers=8,
)

# Compute pair-wise Clip-space cosine similarity
final_score_1 = get_clip_score(clip_model, clip_features, original_features_1, DEVICE)
final_score_2 = get_clip_score(clip_model, clip_features, original_features_2, DEVICE)

# Store both scores in a text file named after the embedding file.
clip_score_dir = f"{all_embedding_path}/i2i_score"
os.makedirs(clip_score_dir, exist_ok=True)
clip_score_path = f"{clip_score_dir}/{embeds_name}_i2i_score.txt"
score_text = "CLIP image2image score: {}&{}".format(final_score_1, final_score_2)
with open(clip_score_path, "w") as score_file:
    score_file.write(score_text)
print("END!!! CLIP image2image score for {} is: {}&{}".format(embeds_name, final_score_1, final_score_2))

# The generated images are no longer needed once they have been scored.
rmtree(clip_image_dir)

  Resize(n_px, interpolation=Image.BICUBIC),
100%|██████████| 1/1 [00:01<00:00,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00,  1.01it/s]
100%|██████████| 1/1 [00:02<00:00,  2.76s/it]

END!!! CLIP image2image score for learned_embeds_factor=0.12&learned_embeds_factor=0.5 is: 0.51806640625&0.65576171875



