In [None]:
# Shell escape: query each GPU's name, total memory and free memory, printed as CSV without a header row.
!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader

In [1]:
import torch
import os
from diffusers import StableDiffusionPipeline
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

# Local directory of the pretrained model; the tokenizer and the text encoder
# are loaded from its "tokenizer" and "text_encoder" subfolders below.
pretrained_model_name_or_path = "/root/autodl-tmp/stable_diffusion/stable-diffusion-v1-5"

# Measure used when ranking vocabulary embeddings against the initializer:
# "Vector_dot", "Cosine_similarity" or "L2".
dist_type = "Vector_dot"

tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
# Word whose embedding seeds the candidate search; it must encode to exactly one token.
initialization_token = "chair" # change here when using other datasets
vocab_num = len(tokenizer)
# Input embedding table of the text encoder; it is indexed by token id below.
vocab_embedding = text_encoder.get_input_embeddings().weight.data
# Encode without special tokens so only the word's own ids are returned.
token_ids = tokenizer.encode(initialization_token, add_special_tokens=False)
# Require exactly one token id: more than one means the word was split into
# several tokens, and zero ids (e.g. an empty string) would otherwise surface
# as an opaque IndexError on the lookup below.
if len(token_ids) != 1:
    raise ValueError("The initializer token must be a single token.")
initializer_token_id = token_ids[0]
initialization_embedding = vocab_embedding[initializer_token_id]

In [10]:
# Number of closest vocabulary embeddings to keep.
# change M to get satisfying rank
M = 768

if dist_type == "Vector_dot":
    # Dot product of every vocabulary embedding with the initializer embedding;
    # the M largest scores are kept.
    vocab_dist = torch.mm(vocab_embedding, initialization_embedding.unsqueeze(1)).squeeze(1)
    _, vocab_max_ids = torch.topk(vocab_dist, M, 0, True)
    candidate_embedding_matrix = vocab_embedding[vocab_max_ids]  # [M, embedding_dim]
elif dist_type == "Cosine_similarity":
    embeds_matrix = initialization_embedding.unsqueeze(0).expand(vocab_num,
                                                                 initialization_embedding.shape[0])
    # Reduce over the embedding axis (dim=1) so there is one score per
    # vocabulary entry. Reducing over dim=0 would instead give one score per
    # embedding dimension, and topk would then select unrelated row indices.
    vocab_dist = torch.cosine_similarity(embeds_matrix, vocab_embedding, dim=1)
    _, vocab_max_ids = torch.topk(vocab_dist, M, 0, True)
    candidate_embedding_matrix = vocab_embedding[vocab_max_ids]  # [M, embedding_dim]
elif dist_type == "L2":
    embeds_matrix = initialization_embedding.unsqueeze(0).expand(vocab_num,
                                                                 initialization_embedding.shape[0])
    residual_matrix = embeds_matrix - vocab_embedding
    # Euclidean distance per vocabulary entry; the M smallest distances are kept.
    vocab_dist = torch.norm(residual_matrix, 2, 1)
    _, vocab_min_ids = torch.topk(vocab_dist, M, 0, False)
    candidate_embedding_matrix = vocab_embedding[vocab_min_ids]  # [M, embedding_dim]
else:
    # Fail loudly instead of falling through: otherwise the rank below would be
    # computed on a stale matrix from an earlier run, or raise a NameError.
    raise ValueError(
        f"Unsupported dist_type {dist_type!r}; "
        "expected 'Vector_dot', 'Cosine_similarity' or 'L2'."
    )

# calculate the rank of the candidate_embedding_matrix
candidate_rank = torch.linalg.matrix_rank(candidate_embedding_matrix)
# The denominator is the embedding width, read from the table rather than hard-coded.
print(f"The rank is {candidate_rank}/{vocab_embedding.shape[1]}")

The rank is 768/768
