"label_emb.pt" file #12
Sorry for the very late response; I was busy with my graduation.

import os
import argparse

import torch
from tqdm import tqdm
from clip.clip import load, tokenize


def main(args):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, _ = load(args.clip_path, device=device, jit=False)
    model = model.eval()

    # Unseen (81) and seen (925) concept lists; seen labels are placed first.
    with open(os.path.join(args.data_path, "Concepts81.txt")) as f:
        unseen_labels = f.readlines()
    with open(os.path.join(args.data_path, "Concepts925.txt")) as f:
        seen_labels = f.readlines()
    label1006 = seen_labels + unseen_labels

    # Wrap each concept in a prompt and tokenize: shape (1006, 77).
    label_token = torch.cat(
        [tokenize(f"There is a {c.strip()} in the scene") for c in label1006]
    ).to(device)

    # 768 is the text embedding width of CLIP ViT-L/14; other backbones differ
    # (e.g. 512 for ViT-B/32), so adjust this to match --clip-path.
    label_embed = torch.zeros((label_token.shape[0], 768))
    with torch.no_grad():
        for i, label in enumerate(tqdm(label_token)):
            # encode_text returns shape (1, 768); store that row on the CPU as float32.
            label_embed[i] = model.encode_text(label.unsqueeze(0))[0].float().cpu()

    torch.save(label_embed, os.path.join(args.data_path, "label_emb_nus.pt"))
    print("Embedding Shape:", label_embed.shape)
    print("Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Directory containing Concepts81.txt and Concepts925.txt.
    parser.add_argument("--data-path", type=str, required=True)
    # CLIP model name (e.g. ViT-L/14) or path to a CLIP checkpoint.
    parser.add_argument("--clip-path", type=str, required=True)
    args = parser.parse_args()
    main(args)

Best,
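The saved tensor has one row per concept, in the order the script builds them: the seen concepts from Concepts925.txt first, then the unseen ones from Concepts81.txt (so rows 0-924 are seen and 925-1005 unseen, assuming the files contain as many lines as their names suggest). The pure-Python sketch below uses a hypothetical helper, `build_prompts`, to reproduce that ordering and the prompt template, so you can look up which row of the embedding file belongs to which concept without loading CLIP.

```python
from typing import Dict, List, Tuple

# Prompt template copied from the script above.
PROMPT_TEMPLATE = "There is a {} in the scene"


def build_prompts(
    seen_lines: List[str], unseen_lines: List[str]
) -> Tuple[List[str], Dict[str, int]]:
    """Hypothetical helper: return the prompts in the order the script
    encodes them, plus a concept -> row-index map for the saved tensor.

    Seen concepts come first, then unseen ones, mirroring
    `label1006 = seen_labels + unseen_labels`.
    """
    concepts = [c.strip() for c in seen_lines + unseen_lines]
    prompts = [PROMPT_TEMPLATE.format(c) for c in concepts]
    # If a concept appeared twice, this map would keep its last row index.
    row_of = {c: i for i, c in enumerate(concepts)}
    return prompts, row_of


if __name__ == "__main__":
    # Toy stand-ins for the lines read from Concepts925.txt and Concepts81.txt.
    seen = ["airport\n", "animal\n"]
    unseen = ["bear\n"]
    prompts, row_of = build_prompts(seen, unseen)
    print(prompts[2])      # There is a bear in the scene
    print(row_of["bear"])  # 2
```

With the real files, `row_of[concept]` gives the index to use on the tensor returned by `torch.load` for `label_emb_nus.pt`.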
Thank you so much!
Nice job! I would like to ask how the "label_emb.pt" file was generated, and what I should do if I want to generate it myself.
Thanks!