"label_emb.pt" file #12

Closed
HaoliangZhou opened this issue Jun 7, 2023 · 2 comments


@HaoliangZhou

Nice job! How was the "label_emb.pt" file produced, and what should I do if I want to generate it myself?
Thanks!

@sunanhe
Owner

sunanhe commented Jul 4, 2023

Sorry for the very late response; I was busy with my graduation.
We generated the label embedding file using the following code.

import os
import argparse
from tqdm import tqdm 

import torch

from clip.clip import load, tokenize

def main(args):

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model, _ = load(args.clip_path, device=device, jit=False)
    model = model.eval()

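    # Read one label per line; the context managers close each file promptly.
    with open(os.path.join(args.data_path, "Concepts81.txt")) as f:
        unseen_labels = f.readlines()
    with open(os.path.join(args.data_path, "Concepts925.txt")) as f:
        seen_labels = f.readlines()

    # Seen labels come first, then unseen; rows of the saved tensor follow this order.
    label1006 = seen_labels + unseen_labels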

    label_token = torch.cat([tokenize(f"There is a {c.strip()} in the scene") for c in label1006]).to(device)
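    # 768 must equal the output width of model.encode_text for the loaded CLIP
    # checkpoint (e.g. ViT-L/14 gives 768, ViT-B/32 gives 512).
    label_embed = torch.zeros((label_token.shape[0], 768))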

    with torch.no_grad():
        for i, label in enumerate(tqdm(label_token)):
            label_embed[i] = model.encode_text(label.unsqueeze(0))

    torch.save(label_embed, os.path.join(args.data_path, "label_emb_nus.pt"))

    print("Embedding Shape:", label_embed.shape)
    print("Done!")


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    # Both paths are needed by main(); a missing one would fail later in os.path.join or load.
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument("--clip-path", type=str, required=True)
    args = parser.parse_args()

    main(args)
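
For anyone adapting the script above to their own label set, the prompt-building step can be pulled out and checked without loading CLIP. The following is only a minimal sketch: `build_prompts` is a hypothetical helper that is not part of the original script, and unlike the original it skips blank lines, on the assumption that an empty label should not get an embedding row. The template string is the one used in the script.

```python
from typing import Iterable, List

# Prompt template copied from the script above.
TEMPLATE = "There is a {} in the scene"


def build_prompts(label_lines: Iterable[str], template: str = TEMPLATE) -> List[str]:
    """Turn raw label lines into text prompts.

    Hypothetical helper for illustration; it is not part of the original script.
    """
    prompts = []
    for line in label_lines:
        name = line.strip()
        if not name:
            # Assumption: blank lines (e.g. a trailing newline) carry no label.
            continue
        prompts.append(template.format(name))
    return prompts


if __name__ == "__main__":
    # Made-up labels for illustration; replace with the lines of your own label file.
    demo_lines = ["dog\n", "airport\n", "\n"]
    for prompt in build_prompts(demo_lines):
        print(prompt)
```

The resulting list can then be passed through `tokenize` and `model.encode_text` in the same way as in the script. Row i of the saved tensor corresponds to prompt i, so the order of the label lines should be kept fixed.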

Best,
Sunan

@sunanhe sunanhe closed this as completed Jul 4, 2023
@HaoliangZhou
Author

Thank you so much!
