1. Set the path passed to `sys.path.append` to your project directory.

2. Set `Config.dataset` to the name of the dataset you want to use.

3. For the ViT-L architecture, set `Config.arch='ViT-L/14'` and `Config.f_classifier='./cache/vocabulary_classifier_L.pth'`.

In [22]:
import sys
sys.path.append('/home/sheng/sheng-eatamath/S3A/')

import os
import json
import re
import time
import pickle
from pprint import pprint
from tqdm import tqdm
import numpy as np

import torch 
import torch.nn as nn
import torch.nn.functional as F
import torchvision

import model as clip
from model import tokenize
from data.build_dataset import build_transform
from data.imagenet_datasets import get_datasets_rzsc
from data.vocab import get_vocab, Vocab

class Config:
    """Notebook-wide settings, read as plain attributes by the helpers in this file."""

    # --- model ---------------------------------------------------------
    # CLIP architecture name and the device the model is loaded onto.
    arch = 'ViT-B/16'
    device = 'cuda:1'
    # Optional checkpoint path; when falsy no extra weights are loaded.
    clip_checkpoint = None

    # --- data ----------------------------------------------------------
    # Dataset name.
    dataset = 'make_entity13'
    # Number of sampled classes for ImageNet-100.
    n_sampled_classes = 100
    input_size = 224
    batch_size = 256
    # Per-channel normalization statistics
    # (presumably consumed by build_transform -- TODO confirm).
    image_mean = (0.48145466, 0.4578275, 0.40821073)
    image_std = (0.26862954, 0.26130258, 0.27577711)

    # --- files ---------------------------------------------------------
    # Precomputed 21k CLIP vocabulary classifier.
    f_classifier = './cache/vocabulary_classifier.pth'
    # CLIP template file name (without the ``.json`` extension).
    templates_name = 'templates'

    # --- misc ----------------------------------------------------------
    seed = 0
    
# Settings instance handed to the helper functions in this notebook.
args = Config()

def load_templates(args):
    """Return the ``'imagenet'`` entry of the prompt-template JSON file.

    The file is resolved relative to the current working directory as
    ``../<args.templates_name>.json``.

    Args:
        args: object exposing a ``templates_name`` attribute (file stem).

    Returns:
        Whatever is stored under the ``'imagenet'`` key of that JSON file.
    """
    template_path = f'../{args.templates_name}.json'
    with open(template_path, 'rb') as fh:
        return json.load(fh)['imagenet']

# Prompt templates read from the JSON file named by args.templates_name.
templates = load_templates(args)
# Vocabulary object from data.vocab; its `classnames` are what build_classifier encodes later on.
vocab = get_vocab()

get_vocab in21k


In [10]:
def load_clip2(args):
    """Load the CLIP model named by ``args.arch`` and print a short summary.

    When ``args.clip_checkpoint`` is truthy, the ``'model_ema'`` state dict of
    that checkpoint is loaded on top (non-strictly) after dropping the first
    ``len('model.')`` characters of every key. The model is then moved to
    ``args.device`` and switched to eval mode.

    Args:
        args: settings object providing ``arch``, ``device`` and
            ``clip_checkpoint``.

    Returns:
        The loaded model.
    """
    model = clip.load(args.arch, device=args.device)

    if args.clip_checkpoint:
        checkpoint = torch.load(args.clip_checkpoint, map_location='cpu')
        prefix_len = len('model.')
        # The prefix is cut positionally from every key, whether or not it is present.
        ema_state = {
            name[prefix_len:]: tensor
            for name, tensor in checkpoint['model_ema'].items()
        }
        model.load_state_dict(ema_state, strict=False)

    model.to(args.device).eval()

    # Read every reported attribute up front so nothing is printed if one is missing.
    resolution = model.visual.input_resolution
    ctx_len = model.context_length
    n_tokens = model.vocab_size

    n_params = np.sum([int(np.prod(p.shape)) for p in model.parameters()])
    report = (
        ("Model parameters:", f"{n_params:,}"),
        ("Input resolution:", resolution),
        ("Context length:", ctx_len),
        ("Vocab size:", n_tokens),
    )
    for label, value in report:
        print(label, value)
    return model


def build_classifier(args, model, templates, vocab_classnames, parent_classnames=None):
    """Encode every class name with every prompt template into one text classifier.

    Class names are processed in chunks of at most 64. For each chunk, every
    name is formatted into every template, tokenized, and passed through
    ``model.encode_text``. The per-template embeddings are L2-normalized,
    averaged per class, and the average is normalized again.

    Args:
        args: settings object; only ``args.device`` is read here.
        model: object exposing ``encode_text``.
        templates: sequence of format strings with one ``{}`` placeholder.
        vocab_classnames: non-empty sequence of class-name strings.
        parent_classnames: unsupported; must be ``None``.

    Returns:
        A CPU tensor with one row per class name, in input order.

    Raises:
        ValueError: if ``vocab_classnames`` is empty.
    """
    assert parent_classnames is None
    batch_size = 64

    n_classes = len(vocab_classnames)
    if n_classes == 0:
        raise ValueError('vocab_classnames must contain at least one class name')

    # Ceil division: floor division gave 0 sections for < batch_size names
    # (np.array_split then raises) and let chunks grow beyond batch_size.
    n_chunks = -(-n_classes // batch_size)

    zeroshot_weights = []
    with torch.no_grad(), tqdm(total=n_chunks) as pbar:
        for classname_set in np.array_split(vocab_classnames, n_chunks):
            # Class-major ordering: all templates for one class are contiguous,
            # which the view() below depends on.
            texts = [
                template.format(classname)
                for classname in classname_set
                for template in templates
            ]
            tokens = tokenize(texts).to(args.device)
            class_embeddings = model.encode_text(tokens).float()
            # -> (classes in chunk, templates, embedding dim)
            class_embeddings = class_embeddings.view(-1, len(templates), class_embeddings.size(-1))
            class_embeddings = F.normalize(class_embeddings, dim=-1)
            class_embedding = class_embeddings.mean(dim=1)
            class_embedding /= class_embedding.norm(dim=-1, keepdim=True)
            zeroshot_weights.append(class_embedding.cpu())
            pbar.update(1)

    return torch.cat(zeroshot_weights, dim=0)

In [23]:
# Build the dataset with the evaluation transform and wrap it in a loader.
transform_val = build_transform(is_train=False, args=args, train_config=None)
dataset = get_datasets_rzsc(args, vocab, is_train=True, transform=transform_val, seed=0)
loader_val = torch.utils.data.DataLoader(dataset, num_workers=8, batch_size=args.batch_size, shuffle=False)
print('dataset size', len(dataset))

# Create the output directory before the slow classifier build, so the
# final torch.save cannot fail on a missing folder after minutes of work.
classifier_dir = os.path.dirname(args.f_classifier)
if classifier_dir:
    os.makedirs(classifier_dir, exist_ok=True)

# Encode the whole vocabulary with CLIP's text encoder and cache the result.
model = load_clip2(args)
classifier = build_classifier(args, model, templates, vocab.classnames)
torch.save(classifier, args.f_classifier)

dataset size 334718
===load nonjit===
missing keys:
['visual.projection_head.0.weight', 'visual.projection_head.0.bias', 'visual.projection_head.2.weight', 'visual.projection_head.2.bias']
Model parameters: 150,408,193
Input resolution: 224
Context length: 77
Vocab size: 49408


100%|██████████| 313/313 [03:21<00:00,  1.55it/s]
