In [1]:
import torch
import skimage.io as io
import clip
from PIL import Image
import pickle
import json
import os
from tqdm import tqdm
import argparse

In [2]:
#先使用clip中的process对image进行加载，然后用clip_model将imageencode成prefix

def main(out_path, file_path, clip_model_type = "ViT-B/32"):
    # clip_model_type = "ViT-B/32"
    device = torch.device('cuda:0')
    clip_model_name = clip_model_type.replace('/', '_')
    out_train_path = os.path.join(out_path,f'oscar_split_{clip_model_name}_train.pkl')
    out_test_path = os.path.join(out_path,f'oscar_split_{clip_model_name}_test.pkl')
    clip_model, preprocess = clip.load(clip_model_type, device=device, jit=False)
    filename,captions = [],[]
    with open(os.path.join(file_path,'captions.txt'), 'r') as file:
        lines = file.readlines()
        for line in lines[1:]:
            temp = line.split('.jpg,')
            filename.append(temp[0] + '.jpg')
            captions.append(temp[1][:-1])

    print("%0d captions loaded from json " % len(filename))
    train_embeddings,train_captions = [],[]
    total_size = len(filename)
    train_size = int(total_size*0.9)
    for i in tqdm(range(train_size)):
        d = {'image_id':filename[i],'caption':captions[i]}
        file = os.path.join(file_path,'images',filename[i])
        image = io.imread(file)
        image = preprocess(Image.fromarray(image)).unsqueeze(0).to(device)
        with torch.no_grad():
            prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
            # prefix = clip_model.encode_image(image).cpu()
        d["clip_embedding"] = i
        train_embeddings.append(prefix)
        train_captions.append(d)
    with open(out_train_path, 'wb') as f:
        pickle.dump({"clip_embedding": torch.cat(train_embeddings, dim=0), "captions": train_captions}, f)

    print('Done')
    print("%0d embeddings saved " % train_size)
    return 0

In [5]:
main(clip_model_type = "ViT-B/32",out_path='/root/image caption/data/flicker8k',file_path='/root/image caption/flickr8k')

40455 captions loaded from json 


100%|██████████| 36409/36409 [07:48<00:00, 77.63it/s]


Done
36409 embeddings saved 


0

In [6]:
main(clip_model_type = 'RN50x4',out_path='/root/image caption/data/flicker8k',file_path='/root/image caption/flickr8k')

40455 captions loaded from json 


100%|██████████| 36409/36409 [08:50<00:00, 68.68it/s]


Done
36409 embeddings saved 


0

In [3]:
main(clip_model_type = "ViT-B/32",out_path='/root/image caption/data/flicker30k',file_path='/root/image caption/flicker30k')

158914 captions loaded from json 


100%|██████████| 143022/143022 [32:55<00:00, 72.41it/s]


Done
143022 embeddings saved 


0

In [4]:
main(clip_model_type = "RN50x4",out_path='/root/image caption/data/flicker30k',file_path='/root/image caption/flicker30k')

158914 captions loaded from json 


100%|██████████| 143022/143022 [38:08<00:00, 62.49it/s]


Done
143022 embeddings saved 


0