In [1]:
import clip
import torch
from datasets import load_from_disk, Sequence, Value, Features
from tqdm import tqdm

In [2]:
# Load the pretrained CLIP "ViT-B/32" checkpoint onto GPU index 1; clip.load
# hands back the model together with its matching image preprocessing callable.
model, preprocess = clip.load(name="ViT-B/32", device="cuda:1")

In [3]:
# Read the previously saved COCO caption dataset back from its on-disk Arrow directory.
coco_dataset_dict = load_from_disk(dataset_path="/data/qiaowei/coco2014/coco_caption_arrow/")
# Positional unpacking: requires exactly two splits and depends on their stored
# order (presumably train first, then valid -- verify against the saved dataset).
coco_train_dataset, coco_valid_dataset = tuple(coco_dataset_dict.values())

In [4]:
# Target schema: every existing training-split feature plus a fixed-length
# (512) float32 sequence column named "clip" for the image embedding.
coca_features = Features(
    **coco_train_dataset.features,
    clip=Sequence(feature=Value(dtype="float32"), length=512)
)
# Work on just the first 50 validation rows from here on.
coco_valid_dataset = coco_valid_dataset.select(indices=range(50))

In [5]:
# Encode every validation image with CLIP and keep each embedding as a plain
# Python list so it can later be attached to the dataset as a column.
#
# This is inference only, so the loop runs under torch.no_grad(): without it,
# autograd records a graph for every forward pass that is never used, which
# wastes GPU memory and time.
# Reuse whatever device the CLIP weights live on instead of repeating the
# device string from the loading cell, so the two can never drift apart.
clip_device = next(model.parameters()).device
image_clip_encodes = []
with torch.no_grad():
    for record in tqdm(coco_valid_dataset):
        image = record["image"]
        # Add a leading batch dimension of size 1 before moving to the model's device.
        processed_image = preprocess(image).unsqueeze(0).to(clip_device)
        # Drop that batch dimension again and convert the embedding to a list of floats.
        image_clip_encodes.append(model.encode_image(processed_image).squeeze(0).tolist())

100%|██████████| 50/50 [00:01<00:00, 36.27it/s]


In [6]:
# Attach the embeddings as a new "clip" column, then cast the result to the
# schema defined in coca_features.
coca_valid_dataset = (
    coco_valid_dataset
    .add_column(name="clip", column=image_clip_encodes)
    .cast(features=coca_features)
)

Loading cached processed dataset at /data/qiaowei/coco2014/coco_caption_arrow/valid/cache-1fb7a102643c7c24.arrow
Loading cached processed dataset at /data/qiaowei/coco2014/coco_caption_arrow/valid/cache-551d57941a0f5ed2.arrow


In [7]:
import torch
from transformers import AutoTokenizer

In [8]:
# Fetch the pretrained "gpt2" tokenizer used to tokenize the captions.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="gpt2")

In [9]:
class CocaCollator:
    """Collate caption records into one tokenized, padded batch with image features.

    Each record must be a mapping providing a ``"caption"`` entry (passed to
    the tokenizer) and a ``"clip"`` entry (a numeric sequence that is stacked
    into a float tensor).
    """

    def __init__(self, tokenizer, max_length=64):
        """Store the tokenizer and make sure it is able to pad.

        Args:
            tokenizer: Callable tokenizer exposing ``pad_token`` and
                ``eos_token`` attributes. It is modified in place when it has
                no pad token.
            max_length: Maximum tokenized caption length; longer captions are
                truncated. Defaults to 64, the value that used to be
                hard-coded.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Fall back to the EOS token for padding only when the tokenizer has no
        # pad token of its own; overwriting an existing pad token would make
        # padding indistinguishable from a genuine end-of-sequence marker.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

    def __call__(self, records):
        """Build a model-ready batch from a list of records.

        Args:
            records: Sequence of mappings with ``"caption"`` and ``"clip"`` keys.

        Returns:
            Whatever the tokenizer returns for the captions (padded to the
            longest caption in the batch, truncated to ``max_length``,
            requested as PyTorch tensors), with an extra ``"image_feature"``
            entry holding the stacked ``"clip"`` values as a float tensor.
        """
        texts = [record["caption"] for record in records]
        image_feature = torch.tensor([record["clip"] for record in records], dtype=torch.float)
        encodes = self.tokenizer(
            texts,
            padding=True,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        encodes["image_feature"] = image_feature
        return encodes

In [10]:
# Build the batch collator around the GPT-2 tokenizer loaded above.
collator = CocaCollator(tokenizer=tokenizer)

In [11]:
from itertools import islice

In [12]:
# Pull the first 8 rows of the augmented validation set as a small sample batch.
records = [row for row in islice(coca_valid_dataset, 8)]

In [13]:
# Collate the sampled rows into one tokenized batch with image features attached.
batch = collator(records=records)

In [14]:
from coca.caption_model import CapitionModel

In [15]:
# Instantiate the caption model.
# NOTE(review): this rebinds `model`, which until now referred to the CLIP
# model from clip.load; re-running the image-encoding cell after this one
# would use the caption model instead. A distinct name would avoid that.
# NOTE(review): the meaning of the argument 8 is not visible here -- confirm
# against coca.caption_model.
model = CapitionModel(8)

In [16]:
# Display the module's training flag; the recorded output is True.
model.training

True

In [17]:
# Run one forward pass on the collated batch, unpacking its entries as keyword
# arguments; the recorded output contains a loss and a logits tensor.
# NOTE(review): displaying the whole output object prints the full logits
# tensor; consider showing only the loss to keep the notebook readable.
model(**batch)

CausalLMOutputWithCrossAttentions(loss=tensor(6.9112, grad_fn=<NllLossBackward>), logits=tensor([[[ -35.3246,  -34.9186,  -38.7027,  ...,  -42.8423,  -42.2760,
           -35.2916],
         [ -89.9451,  -88.8907,  -98.3121,  ...,  -99.3085,  -98.2210,
           -92.3134],
         [ -90.0260,  -89.1444,  -96.6579,  ...,  -97.8730,  -99.7618,
           -91.4128],
         ...,
         [ -92.9337,  -85.1563,  -88.8080,  ..., -108.5603, -109.0061,
           -93.4448],
         [ -92.8740,  -85.0904,  -88.7367,  ..., -108.4757, -108.9415,
           -93.4318],
         [ -92.7794,  -84.9884,  -88.6359,  ..., -108.3597, -108.8512,
           -93.3601]],

        [[ -32.7503,  -32.4580,  -36.0390,  ...,  -40.2331,  -39.7657,
           -32.7707],
         [ -79.5937,  -79.3641,  -85.7301,  ...,  -91.7535,  -87.7830,
           -82.0199],
         [ -96.4354,  -96.4151, -101.7587,  ..., -108.6933, -106.5354,
           -98.9826],
         ...,
         [ -96.4525,  -88.7787,  -92.2619,  