In [1]:
from itertools import islice

import clip
import torch
from datasets import DatasetDict, Features, Sequence, Value, load_from_disk
from tqdm import tqdm

# CLIP ViT-B/32 model and its image preprocessing transform, on CUDA device index 1.
model, preprocess = clip.load("ViT-B/32", device="cuda:1")

coco_dataset_dict = load_from_disk("/data/qiaowei/coco2014/coco_caption_arrow/")
# The splits are unpacked by position, so this relies on the saved DatasetDict
# listing the train split first and the validation split second.
coco_train_dataset, coco_valid_dataset = list(coco_dataset_dict.values())

# Output schema: every source column plus a fixed-length, 512-float "clip" column
# that will hold the image embedding.
coca_features = Features(
    {
        **coco_train_dataset.features,
        "clip": Sequence(Value("float32"), length=512),
    }
)

def get_batch(dataset, batch_size=32):
    """Yield consecutive lists of up to ``batch_size`` items taken from ``dataset``.

    The last list may be shorter than ``batch_size``. An empty input yields nothing.
    """
    iterator = iter(dataset)
    while True:
        chunk = list(islice(iterator, batch_size))
        if not chunk:
            return
        yield chunk

BATCH_SIZE = 32


def encode_clip_images(dataset, batch_size=BATCH_SIZE, device="cuda:1"):
    """Return one CLIP image embedding (a list of floats) per record of ``dataset``.

    Records are read in order, ``batch_size`` at a time, via ``get_batch``. Each
    record's ``"image"`` goes through the CLIP ``preprocess`` transform. The batch is
    then moved to ``device`` and encoded with ``model.encode_image``. The returned
    list is aligned row-for-row with ``dataset``, as ``Dataset.add_column`` requires.
    """
    # Ceiling division so the progress bar also counts the final partial batch.
    n_batches = (len(dataset) + batch_size - 1) // batch_size
    embeddings = []
    # Inference only: skip autograd bookkeeping for the forward passes.
    with torch.no_grad():
        for records in tqdm(get_batch(dataset, batch_size), total=n_batches):
            pixels = torch.stack([preprocess(record["image"]) for record in records]).to(device)
            # Do not ``.squeeze()`` here. On a 1-record batch it would collapse the
            # output to one flat vector, and ``extend`` would then add each float as
            # its own row. The train split displayed at the end of this notebook has
            # 414113 rows (414113 % 32 == 1), so its last batch hits exactly that case.
            embeddings.extend(model.encode_image(pixels).tolist())
    return embeddings


coca_valid_dataset = coco_valid_dataset.add_column(
    "clip", encode_clip_images(coco_valid_dataset)
).cast(coca_features)
coca_train_dataset = coco_train_dataset.add_column(
    "clip", encode_clip_images(coco_train_dataset)
).cast(coca_features)

coca_dataset_dict = DatasetDict(train=coca_train_dataset, valid=coca_valid_dataset)
coca_dataset_dict.save_to_disk("data/coca_arrow")

In [5]:
from datasets import Dataset, load_from_disk

In [3]:
# NOTE(review): no cell in this notebook writes "data/coca_train.arrow" (the first
# cell saves a DatasetDict to "data/coca_arrow"), so this presumably reads a file
# produced elsewhere. Confirm. `dt` is also reassigned by the following cell.
dt = Dataset.from_file(filename="data/coca_train.arrow")

In [7]:
# Reload the original caption dataset (same path as in the first cell).
dt = load_from_disk(dataset_path="/data/qiaowei/coco2014/coco_caption_arrow/")

In [8]:
# Display the "train" split; its column names and row count are rendered below.
dt["train"]

Dataset({
    features: ['image', 'caption'],
    num_rows: 414113
})