In [62]:
import io
import os
import urllib
import urllib.request

import torch
from datasets import load_dataset
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, CLIPTextModel

# Turn off the tokenizers library's internal parallelism; presumably this is to
# avoid its fork warning once DataLoader worker processes start.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

# One Hugging Face cache directory shared by the dataset and the tokenizer.
path_to_cache = "/scratch/bbug/priyamm2/huggingface_cache"
dataset_name = "conceptual_captions"

### Load Dataset (keep only the "train" split) ###
dataset = load_dataset(dataset_name, cache_dir=path_to_cache)["train"]

### Load Tokenizer (CLIP ViT-B/32 text tokenizer) ###
tokenizer = AutoTokenizer.from_pretrained(
    "openai/clip-vit-base-patch32",
    cache_dir=path_to_cache,
)

In [None]:
class ConceptualCpationsDataset(Dataset):
    """Map-style dataset that downloads Conceptual Captions images on access.

    Each record of ``hf_dataset`` must provide ``"image_url"`` and ``"caption"``
    keys. The image is fetched over HTTP inside ``__getitem__``. When it cannot
    be downloaded or decoded, the sample is returned as ``(None, caption)``.
    Pair this dataset with a collate function that drops such samples.

    Args:
        hf_dataset: indexable, sized collection of records (e.g. a Hugging Face
            ``Dataset`` split) with ``"image_url"`` and ``"caption"`` fields.
        img_size: side length the image is resized to; output is square.
        num_retries: number of download attempts per image. ``0`` means no
            attempt is made and the image is always ``None``.
        timeout: per-attempt timeout in seconds passed to ``urlopen``.
    """

    def __init__(self, hf_dataset, img_size=128, num_retries=2, timeout=1):
        self.hf_dataset = hf_dataset
        self.num_retries = num_retries
        self.timeout = timeout

        # Pipeline: resize -> random horizontal flip (augmentation)
        # -> (3, H, W) float tensor in [0, 1] -> rescale to [-1, 1].
        # Normalize with mean=std=0.5 computes (t - 0.5) / 0.5 == 2t - 1, the
        # same mapping as a lambda. Unlike a lambda it is picklable, so
        # DataLoader workers also start under the "spawn" multiprocessing method.
        self.image2tensor = transforms.Compose(
            [
                transforms.Resize((img_size, img_size)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            ]
        )

    def __len__(self):
        return len(self.hf_dataset)

    def _download_image(self, image_url):
        """Fetch ``image_url`` and decode it as an RGB PIL image.

        Returns ``None`` when all ``self.num_retries`` attempts fail.
        """
        for _ in range(self.num_retries):
            try:
                request = urllib.request.Request(image_url)
                with urllib.request.urlopen(request, timeout=self.timeout) as response:
                    # convert() forces the decode while the bytes are in hand.
                    return Image.open(io.BytesIO(response.read())).convert("RGB")
            except Exception:
                # Best effort: dead links, HTTP errors, timeouts, malformed URLs
                # and undecodable payloads are all treated as a failed attempt.
                # ``Exception`` (unlike a bare ``except``) still lets
                # KeyboardInterrupt / SystemExit propagate.
                continue
        return None

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for sample ``idx``.

        ``image`` is a ``(3, img_size, img_size)`` tensor scaled to [-1, 1],
        or ``None`` when the download failed.
        """
        item = self.hf_dataset[idx]
        caption = item["caption"]

        image = self._download_image(item["image_url"])
        if image is not None:
            image = self.image2tensor(image)

        return image, caption
    

def collate_fn(batch, text_tokenizer=None):
    """Collate ``(image, caption)`` pairs into a batch, dropping failed downloads.

    Samples whose image is ``None`` are removed together with their caption, so
    images and captions stay aligned.

    Args:
        batch: list of ``(image, caption)`` tuples, as returned by the dataset's
            ``__getitem__``. ``image`` is a tensor (all of the same shape) or
            ``None``.
        text_tokenizer: optional tokenizer callable with the Hugging Face call
            signature. Defaults to the module-level ``tokenizer``.

    Returns:
        dict with the following keys:

        * ``"images"``: the kept images stacked along a new leading batch
          dimension, or ``None`` when no image in the batch is usable.
        * ``"context"``: token ids of the kept captions, padded to the longest one.
        * ``"mask"``: boolean tensor that is ``True`` at padding positions (the
          inverse of the tokenizer's attention mask).
    """
    kept = [(image, caption) for image, caption in batch if image is not None]

    # Any non-empty list is stacked; a batch with a single valid image must
    # still produce a tensor so that it matches its single caption.
    images = torch.stack([image for image, _ in kept]) if kept else None
    captions = [caption for _, caption in kept]

    tok = tokenizer if text_tokenizer is None else text_tokenizer
    # NOTE(review): when every download in the batch failed, the tokenizer is
    # called with an empty list; confirm the installed tokenizer accepts that.
    annotation = tok(captions, padding=True, return_tensors="pt")

    return {
        "images": images,
        "context": annotation["input_ids"],
        "mask": ~annotation["attention_mask"].bool(),
    }
    
    
    
# Build the on-the-fly-download dataset and a shuffled, multi-worker loader.
ccd = ConceptualCpationsDataset(hf_dataset=dataset)
trainloader = DataLoader(ccd, batch_size=16, collate_fn=collate_fn, num_workers=4, shuffle=True, pin_memory=True)

# Smoke test: iterate the loader and report how many images survived per batch.
for data in tqdm(trainloader):
    # collate_fn can return images=None when no image in the batch was usable;
    # skip such batches instead of crashing on `.shape` of None.
    if data["images"] is None:
        print("skipped batch: no images downloaded")
        continue
    print(data["images"].shape)
    
        
            
                
    

  0%|          | 0/207396 [00:01<?, ?it/s]

torch.Size([11, 3, 128, 128])
torch.Size([9, 3, 128, 128])
torch.Size([14, 3, 128, 128])
torch.Size([13, 3, 128, 128])
torch.Size([10, 3, 128, 128])
torch.Size([11, 3, 128, 128])
torch.Size([10, 3, 128, 128])
torch.Size([10, 3, 128, 128])
torch.Size([9, 3, 128, 128])
torch.Size([11, 3, 128, 128])
torch.Size([10, 3, 128, 128])
torch.Size([12, 3, 128, 128])
torch.Size([11, 3, 128, 128])
torch.Size([9, 3, 128, 128])
torch.Size([12, 3, 128, 128])
torch.Size([10, 3, 128, 128])
torch.Size([10, 3, 128, 128])
torch.Size([15, 3, 128, 128])
torch.Size([13, 3, 128, 128])
torch.Size([8, 3, 128, 128])
torch.Size([10, 3, 128, 128])
torch.Size([8, 3, 128, 128])
torch.Size([10, 3, 128, 128])
torch.Size([8, 3, 128, 128])
torch.Size([10, 3, 128, 128])
torch.Size([10, 3, 128, 128])
torch.Size([13, 3, 128, 128])
torch.Size([12, 3, 128, 128])
torch.Size([12, 3, 128, 128])
torch.Size([9, 3, 128, 128])
torch.Size([12, 3, 128, 128])
torch.Size([9, 3, 128, 128])
torch.Size([6, 3, 128, 128])
torch.Size([12, 3, 