In [9]:
import torch
import os
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

# --- Configuration ---------------------------------------------------------
# Checkpoint directory passed to `from_pretrained`. The SD_MODEL_PATH
# environment variable overrides it so the notebook can run on machines that
# do not have this exact directory; when the variable is unset the original
# path is used unchanged.
pretrained_model_name_or_path = os.environ.get(
    "SD_MODEL_PATH",
    "/root/autodl-tmp/stable_diffusion/stable-diffusion-v1-5",
)
# Word whose embedding neighbourhood is probed in the next cell.
first_token = "animal"
# Word looked up among `first_token`'s nearest vocabulary entries.
second_token = "cat"
# Number of top-scoring vocabulary entries kept as candidates.
mask_k = 576

# --- Model components ------------------------------------------------------
# Only the tokenizer and the text encoder are loaded here; both come from
# sub-folders of the same checkpoint directory.
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="tokenizer",
)
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="text_encoder",
)

In [10]:
# Resolve both probe words to vocabulary ids. Each word must encode to exactly
# one id: a multi-piece encoding has no single embedding row, and an empty
# encoding would otherwise fail on the `[0]` lookup below with an IndexError
# instead of the intended ValueError.
token_ids = tokenizer.encode(first_token, add_special_tokens=False)
if len(token_ids) != 1:
    raise ValueError("The initializer token must be a single token.")
initializer_token_id = token_ids[0]

token_ids_second = tokenizer.encode(second_token, add_special_tokens=False)
if len(token_ids_second) != 1:
    raise ValueError("The second token must be a single token.")
second_token_id = token_ids_second[0]

# Input-embedding table of the text encoder, indexed by token id.
token_embeds = text_encoder.get_input_embeddings().weight.data
initialization_embedding = token_embeds[initializer_token_id]
# Not used in this cell; kept as a notebook-level name for any later cell.
second_embedding = token_embeds[second_token_id]

# Dot product of every vocabulary embedding with the first token's embedding.
# Despite the "dist" in the name this is a similarity score: larger values
# mean closer, which is why the *largest* `mask_k` entries are selected.
vocab_dist_first = torch.mm(
    token_embeds, initialization_embedding.unsqueeze(1)
).squeeze(1)
_, vocab_ids_first = torch.topk(vocab_dist_first, k=mask_k, dim=0, largest=True)

# Report whether the second token is among those top-`mask_k` candidates.
if second_token_id in vocab_ids_first:
    print(f"The {mask_k} candidates for {first_token} already contain {second_token}")
else:
    print(f"The {mask_k} candidates for {first_token} do not contain {second_token}")

The 576 candidates for animal already contain cat
