In [21]:
import os
import json
import numpy as np
from os import path

from tqdm import tqdm
import gzip
import regex as re

In [2]:
#Format the vinvl mapping
def vinvl2cc_mapping(vinvl_annotation, vinvl2cc, vinvl_mapping, missed_mapping):
    """Resolve every VinVL entry to a CC id through an exact caption match.

    Both output containers are filled in place; nothing is returned.

    Args:
        vinvl_annotation: dict of vinvl id -> entry; each entry must have a
            'caption' key.
        vinvl2cc: dict that receives vinvl id -> cc id for every matched entry.
        vinvl_mapping: dict of caption -> cc id used for the lookup.
        missed_mapping: list that receives the vinvl ids with no match (a
            caption that is absent, or present with a None value).
    """
    for vinvl_id, entry in tqdm(vinvl_annotation.items()):
        matched_cc_id = vinvl_mapping.get(entry['caption'], None)
        if matched_cc_id is None:
            missed_mapping.append(vinvl_id)
            continue
        vinvl2cc[vinvl_id] = matched_cc_id

In [3]:
# Pair the CLIP image ids: start by loading the VinVL objects/captions JSON.
cc_objects_captions_path = "/data/home/zmykevin/vinvl_data/CC/cc_objects_captions.json"
with open(cc_objects_captions_path, "r") as f:
    cc_objects_captions = json.load(f)

In [4]:
# Load the original train and validation annotation arrays.
data_path = "/fsx/lyuchen2/vilt_dataset/large_experiments/cmd/cc"
train_image_path = f"{data_path}/training"
val_image_path = f"{data_path}/validation"

# NOTE(review): allow_pickle=True unpickles arbitrary Python objects from the
# .npy files; only keep it for files from a trusted source.
train_annotation, val_annotation = (
    np.load(os.path.join(data_path, annotation_file), allow_pickle=True)
    for annotation_file in ("train_all.npy", "val.npy")
)

In [5]:
# Build the caption <-> image_id lookups from the train and validation annotations.
# Only the first caption of each annotation is used. Captions are dict keys, so
# when the same caption occurs more than once the image_id written last wins
# (validation entries are written after the training ones).
cc_mapping = {}
cc_reverse_mapping = {}
for annotation_split in (train_annotation, val_annotation):
    for ann in annotation_split:
        image_id = ann['image_id']
        current_caption = ann['captions'][0]
        cc_mapping[current_caption] = image_id
        cc_reverse_mapping[image_id] = current_caption

In [46]:
# Check whether each VinVL entry can be mapped to a CC image id via its caption.
vinvl2cc, missed_mapping = {}, []

vinvl2cc_mapping(
    vinvl_annotation=cc_objects_captions,
    vinvl2cc=vinvl2cc,
    vinvl_mapping=cc_mapping,
    missed_mapping=missed_mapping,
)

100%|██████████| 3116254/3116254 [00:08<00:00, 368906.53it/s]


In [47]:
# Number of VinVL entries whose caption was matched to a CC image id.
print(len(vinvl2cc))

3055348


In [49]:
# Attach the matched CC id (None when no match was found) to every VinVL entry.
# NOTE: each entry dict is annotated in place, so the values held by
# cc_objects_captions gain the 'cc_id' key as well; the new dict shares them.
cc_objects_captions_updated = {}
for vinvl_id, vinvl_value in cc_objects_captions.items():
    vinvl_value['cc_id'] = vinvl2cc.get(vinvl_id, None)
    cc_objects_captions_updated[vinvl_id] = vinvl_value
    

In [50]:
# Show the first updated entry (its id, then its content) as a sanity check.
for sample_id, sample_entry in cc_objects_captions_updated.items():
    print(sample_id, sample_entry, sep="\n")
    break

0
{'objects': 'bus bus bus bus building bus bus building shirt man shirt person road shirt person person person person man car man man person man bus person bus door person person person person pant person bus person window window person street person window man bus shirt man roof man man shirt person person boat hat balcony van pant bus sign roof', 'objects_no_rep': 'pant street building shirt man person hat car roof sign door van window road balcony boat bus', 'caption': 'a very typical bus station', 'cc_id': 2901536091}


In [51]:
# Persist the entries (now carrying 'cc_id') back to the captions file.
# This path is the same file the notebook loads its input from, and opening it
# with "w" truncates it immediately. The data is therefore written to a sibling
# temp file first and swapped in with os.replace, so an interrupted dump cannot
# leave the original file empty or half-written.
cc_objects_captions_out = "/data/home/zmykevin/vinvl_data/CC/cc_objects_captions.json"
cc_objects_captions_tmp = cc_objects_captions_out + ".tmp"
with open(cc_objects_captions_tmp, "w") as f:
    json.dump(cc_objects_captions_updated, f)
os.replace(cc_objects_captions_tmp, cc_objects_captions_out)

In [6]:
# Split the VinVL entries by whether a CC id was found for them:
#   missed_vinvl_data: vinvl id -> entry, for entries whose 'cc_id' is None
#   paired_cc_data:    cc id -> list of the vinvl ids mapped to it
missed_vinvl_data = {}
paired_cc_data = {}
for vinvl_id, entry in cc_objects_captions.items():
    entry_cc_id = entry['cc_id']
    if entry_cc_id is None:
        missed_vinvl_data[vinvl_id] = entry
        continue
    paired_cc_data.setdefault(entry_cc_id, []).append(vinvl_id)
print(len(missed_vinvl_data))
# Several vinvl ids can be mapped to one cc_id, so this counts distinct cc ids.
print(len(paired_cc_data))

60906
2163343


In [9]:
# Pick the CLIP test set: up to CLIP_TEST_SIZE VinVL entries whose CC image is
# paired with exactly one VinVL id. Each eligible entry is kept with
# probability ~0.5 while walking paired_cc_data in its insertion order.
import random

CLIP_TEST_SIZE = 10000
CLIP_TEST_SEED = 0
# A dedicated, seeded generator makes the split reproducible: re-running this
# cell rewrites cc_clip_test.json with the same entries instead of a new draw,
# and the global `random` state is left untouched.
clip_test_rng = random.Random(CLIP_TEST_SEED)

clip_test = {}
for cc_id, vinvl_ids in paired_cc_data.items():
    if len(clip_test) >= CLIP_TEST_SIZE:
        break
    if len(vinvl_ids) == 1 and clip_test_rng.uniform(0, 1) > 0.5:
        only_vinvl_id = vinvl_ids[0]
        clip_test[only_vinvl_id] = cc_objects_captions[only_vinvl_id]

# If there are too few eligible entries the set is simply smaller than
# CLIP_TEST_SIZE; no error is raised.
with open("/data/home/zmykevin/vinvl_data/CC/cc_clip_test.json", "w") as f:
    json.dump(clip_test, f)

In [19]:
def default_bpe():
    """Return the default location of the gzipped BPE vocabulary file."""
    vocab_dir = "/checkpoints/zmykevin/models"
    vocab_file = "bpe_simple_vocab_16e6.txt.gz"
    return os.path.join(vocab_dir, vocab_file)

def bytes_to_unicode():
    """
    Build the lookup table from every byte value (0-255) to a one-character string.

    Bytes in the ranges "!".."~", "¡".."¬" and "®".."ÿ" map to the character
    with the same code point. Every other byte (whitespace, control characters
    and the like) is assigned, in increasing byte order, the next unused code
    point starting at 256. BPE can then run on ordinary unicode strings without
    ever seeing whitespace/control characters, and the mapping stays reversible.

    Returns a dict of int byte value -> str, holding the directly mapped bytes
    first and the remapped bytes after them.
    """
    directly_mapped = [
        *range(ord("!"), ord("~")+1),
        *range(ord("¡"), ord("¬")+1),
        *range(ord("®"), ord("ÿ")+1),
    ]
    remapped = [byte for byte in range(2**8) if byte not in directly_mapped]

    byte_values = directly_mapped + remapped
    code_points = directly_mapped + [2**8 + offset for offset in range(len(remapped))]
    return {byte: chr(code_point) for byte, code_point in zip(byte_values, code_points)}

class SimpleTokenizer(object):
    """Byte-level BPE tokenizer driven by a gzipped merges file.

    Attributes set by __init__:
        byte_encoder / byte_decoder: byte value <-> printable character tables
            (from bytes_to_unicode()).
        encoder / decoder: BPE token string <-> integer id tables.
        bpe_ranks: merge pair -> rank; a lower rank is merged earlier.
        cache: memo of token -> space-separated BPE pieces.
        pat: regex that splits cleaned text into tokens to be BPE-encoded.
    """

    # NOTE: the default argument is evaluated once, when the class is defined.
    def __init__(self, bpe_path: str = default_bpe()):
        """Load the merge rules from `bpe_path` and build the vocabulary tables.

        Args:
            bpe_path: path to a gzip-compressed, UTF-8 text file of merge rules.
        """
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Read through a context manager so the gzip handle is always closed.
        with gzip.open(bpe_path) as bpe_file:
            merges = bpe_file.read().decode("utf-8").split('\n')
        # Skip line 0 (presumably a header line -- confirm against the vocab
        # file) and keep at most 49152-256-2 = 48894 merge rules.
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        # Vocabulary order: byte characters, the same characters with the
        # end-of-word marker, one token per merge rule, then the special tokens.
        vocab = list(self.byte_encoder.values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # The special tokens are pre-cached so bpe() returns them unsplit.
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+", re.IGNORECASE)

    def bpe(self, token):
        """Apply the merge rules to one token.

        Args:
            token: non-empty string, already mapped through byte_encoder.

        Returns:
            The resulting BPE pieces joined by single spaces; the last piece
            carries the '</w>' end-of-word marker. Results are memoized in
            self.cache.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        # get_pairs is expected to be defined elsewhere in the notebook; it is
        # not part of this class.
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            # Pick the adjacent pair with the lowest merge rank; unknown pairs rank last.
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` does not occur again: copy the rest unchanged.
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Convert `text` to a list of integer BPE token ids.

        The text is cleaned and lower-cased, split with self.pat, and each
        piece is byte-encoded and passed through bpe(). whitespace_clean and
        basic_clean are expected to be defined elsewhere in the notebook.
        """
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Convert a sequence of token ids back to text.

        Invalid UTF-8 byte sequences are replaced rather than raising, and
        every '</w>' end-of-word marker becomes a space.
        """
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text

In [22]:
# Build the tokenizer from the default BPE vocabulary path (see default_bpe()).
tokenizer = SimpleTokenizer()
# Disabled scratch code for tokenizing a `texts` list (not defined in the cells shown here).
# NOTE(review): the loop variant below calls `tokenizer.encoder(desc)`, but `encoder` is the
# vocabulary dict; the tokenizing method is `encode`.
#text_tokens = [tokenizer.encode(desc) for desc in texts]
# text_tokens = []
# for desc in texts:
#     print(desc)
#     text_tokens.append(tokenizer.encoder(desc))