In [7]:
# Clear every user-defined variable without a confirmation prompt (-f), so the
# cells below start from an empty namespace.
%reset -f

In [2]:
import os
from PIL import Image as PILImage
from tqdm import tqdm
from IPython.display import display
from datasets import Dataset
import torch
from transformers import AutoTokenizer
from transformers import CLIPProcessor
from torchvision import transforms

In [18]:
# Input directory (images + same-stem .txt caption files) and output directory for the
# encoded dataset. The defaults are the original author's local paths; set the
# environment variables below to run the notebook on another machine.
raw_data_path = os.environ.get('CUB_RAW_DATA_PATH', 'D:/Adams/dataset/CUB_200_2011_CAP')
save_dir = os.environ.get('CUB_ENCODED_SAVE_DIR', 'D:/Adams/CUB_200_2011_Encoded/')

In [16]:

def _pad_1d(tensor, length, value):
    """Right-pad a 1-D tensor to ``length`` with ``value``, keeping the tensor's dtype."""
    # dtype is pinned so padding never promotes the column (e.g. int64 mask -> float32).
    pad = torch.full((length - tensor.shape[0],), value, dtype=tensor.dtype)
    return torch.cat((tensor, pad), dim=0)


def raw_data_gen(raw_data_path, max_token_length=10, metadata_max_length=512,
                 clip_model_name="openai/clip-vit-base-patch32",
                 roberta_model_name="xlm-roberta-base",
                 image_extensions=(".png", ".jpg")):
    """Yield encoded examples for ``Dataset.from_generator``.

    For every image in ``raw_data_path`` that has a sibling ``<stem>.txt`` file,
    the image plus a class label derived from the file name are encoded with the
    CLIP processor, and the text file's contents are encoded with the RoBERTa
    tokenizer.

    Args:
        raw_data_path: Directory holding the images and matching ``.txt`` files.
        max_token_length: Fixed length of the CLIP label token ids. Labels are
            truncated to this length, then right-padded with the CLIP tokenizer's
            EOS token id (attention mask padded with 0).
        metadata_max_length: Fixed (padded/truncated) length of the text encoding.
        clip_model_name: Hugging Face id of the CLIP processor.
        roberta_model_name: Hugging Face id of the tokenizer used for the text file.
        image_extensions: File suffixes (single, case-sensitive) treated as images.

    Yields:
        dict with keys ``input_ids``, ``attention_mask``, ``pixel_values``,
        ``encoded_metadata``, ``metadata_attention_mask`` and ``labels``
        (``labels`` is the same tensor as ``input_ids``).
    """
    roberta_tokenizer = AutoTokenizer.from_pretrained(roberta_model_name)
    clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
    # Padding id for the label ids (49407 for the default CLIP checkpoint).
    eos_token_id = clip_processor.tokenizer.eos_token_id

    # sorted(): os.listdir order is arbitrary; a fixed order makes the saved/pushed
    # dataset reproducible across runs and machines.
    for img_file in tqdm(sorted(os.listdir(raw_data_path))):
        stem, ext = os.path.splitext(img_file)
        if ext not in image_extensions:
            continue
        metadata_path = os.path.join(raw_data_path, stem + ".txt")
        if not os.path.exists(metadata_path):
            continue

        # Drop the last two underscore-separated fields of the stem and keep the rest
        # as the label (presumably "<Class_Name>_<n>_<m>" file names — confirm).
        label = " ".join(stem.split("_")[:-2])

        # Context manager closes the file handle; PIL otherwise keeps it open until GC.
        with PILImage.open(os.path.join(raw_data_path, img_file)) as pil_image:
            inputs = clip_processor(text=[label], images=pil_image, return_tensors="pt",
                                    padding=True, truncation=True, max_length=max_token_length)

        # Explicit UTF-8 instead of the platform default (cp1252 on Windows).
        with open(metadata_path, "r", encoding="utf-8") as f:
            # Join non-empty lines with a space so words at line boundaries are not fused.
            text = " ".join(line.strip() for line in f if line.strip())

        encoded_text = roberta_tokenizer(text, padding="max_length", truncation=True,
                                         max_length=metadata_max_length, return_tensors="pt")

        input_ids = _pad_1d(inputs["input_ids"][0], max_token_length, eos_token_id)
        yield {
            "input_ids": input_ids,
            "attention_mask": _pad_1d(inputs["attention_mask"][0], max_token_length, 0),
            "pixel_values": inputs["pixel_values"][0],
            "encoded_metadata": encoded_text["input_ids"][0],
            "metadata_attention_mask": encoded_text["attention_mask"][0],
            "labels": input_ids,
        }

In [17]:
# Build the dataset by running raw_data_gen over raw_data_path. The path is handed over
# through gen_kwargs rather than captured in a lambda, so the generator's arguments are
# passed to `datasets` explicitly.
ds = Dataset.from_generator(raw_data_gen, gen_kwargs={"raw_data_path": raw_data_path})

100%|██████████| 23574/23574 [01:31<00:00, 257.28it/s]examples/s]
Generating train split: 11787 examples [01:33, 126.10 examples/s]


In [None]:
# Check the columns produced by raw_data_gen, then display one encoded example.
expected_columns = {
    "input_ids", "attention_mask", "pixel_values",
    "encoded_metadata", "metadata_attention_mask", "labels",
}
assert set(ds.column_names) == expected_columns, ds.column_names
ds[0]

In [20]:
# Persist the encoded dataset locally (Arrow shards) under save_dir.
ds.save_to_disk(dataset_path=save_dir)

Saving the dataset (15/15 shards): 100%|██████████| 11787/11787 [00:07<00:00, 1616.67 examples/s]


In [21]:
# Upload the encoded dataset to the Hugging Face Hub.
# NOTE(review): the recorded output shows this upload interrupted at 1/15 shards —
# re-run to completion before relying on the Hub copy.
ds.push_to_hub(repo_id="weiywang/CUB_200_2011_Encoded")

Creating parquet from Arrow format: 100%|██████████| 1/1 [00:02<00:00,  2.63s/ba]
Uploading the dataset shards:   7%|▋         | 1/15 [00:35<08:14, 35.29s/it]