In [53]:
import h5py
from pathlib import Path
import pandas as pd
import io 
import torch
from PIL import Image
import numpy as np

class ConceptualCaptionsDataset:
    def __init__(self, root, processor=None, transform=None, use_llava_split=False):
        self.root = Path(root)
        if not self.root.is_dir():
            raise ValueError("Root must be a dir.")
        
        if use_llava_split:
            self.mapper = pd.read_parquet('/home/data/CC3M_LLaVA/mapper_LLaVA_clean.parquet')
        else:
            self.mapper = pd.read_parquet(self.root / 'mapper.parquet')
        self.linker = h5py.File(self.root / 'linker.h5', 'r')
        
        self.transform = transform
        self.processor = processor
        
    def __len__(self):
        return len(self.mapper)
    
    def __getitem__(self, idx):
        dp = self.mapper.iloc[idx]
        
        raw_image = self.linker.get(f"CC3m_{dp.shard}").get('images')[dp.h5_index]
        image = Image.open(io.BytesIO(raw_image))
        caption = dp.caption
        
        if self.processor is not None:
            inputs = self.processor(images=image, text=caption, padding='max_length', return_tensors="pt")
            pixel_values = torch.squeeze(inputs['pixel_values'])
            if self.transform is not None:
                pixel_values = self.transform(pixel_values)
            input_ids = torch.squeeze(inputs['input_ids'])
            token_type_ids = torch.squeeze(inputs['token_type_ids'])
            attention_mask = torch.squeeze(inputs['attention_mask'])

            return pixel_values, input_ids, token_type_ids, attention_mask
        else:
            if self.transform is not None:
                image = self.transform(image)
            return image, caption


# from transformers import VisionTextDualEncoderProcessor, AutoImageProcessor, AutoTokenizer

# processor = VisionTextDualEncoderProcessor(
#         image_processor=AutoImageProcessor.from_pretrained('google/vit-base-patch16-224'), 
#         tokenizer=AutoTokenizer.from_pretrained('google-bert/bert-base-uncased', use_fast=False, max_length=70)
# )

# dataset = ConceptualCaptionsDataset('/home/data/mmssl/CC3m', use_llava_split=True, processor=processor)
# mapper = dataset.mapper
#llava_split = dataset.llava_split

In [57]:
#from torchvision.transforms.functional import pil_to_tensor
import warnings

dataset = ConceptualCaptionsDataset('/home/data/mmssl/CC3m', use_llava_split=True)

with open('indices_large_2.txt', 'w') as f1:
    for idx in range(len(dataset)):
        if idx % 10000 == 0: 
            print(f"Verarbeite Index {idx}")
        try:
            with warnings.catch_warnings():
                warnings.simplefilter('error', Image.DecompressionBombWarning)
                image, _ = dataset[idx]
                # Weitere Bildverarbeitung hier, falls erforderlich
        except Image.DecompressionBombWarning:
            print(f"Warnung: Bild bei Index {idx} ist potenziell zu groß und könnte das System überlasten.")
            f1.write(f"{idx}\n")
            continue
        except Exception as e:
            print(f"Ein anderer Fehler bei Index {idx}: {e}")
            continue

Verarbeite Index 0
Verarbeite Index 10000
Verarbeite Index 20000
Verarbeite Index 30000
Verarbeite Index 40000
Warnung: Bild bei Index 42195 ist potenziell zu groß und könnte das System überlasten.
Verarbeite Index 50000
Verarbeite Index 60000
Verarbeite Index 70000
Verarbeite Index 80000
Verarbeite Index 90000
Warnung: Bild bei Index 99616 ist potenziell zu groß und könnte das System überlasten.
Verarbeite Index 100000
Verarbeite Index 110000
Verarbeite Index 120000
Verarbeite Index 130000
Verarbeite Index 140000
Verarbeite Index 150000
Verarbeite Index 160000
Verarbeite Index 170000
Verarbeite Index 180000
Verarbeite Index 190000
Verarbeite Index 200000
Verarbeite Index 210000
Verarbeite Index 220000
Verarbeite Index 230000
Verarbeite Index 240000
Verarbeite Index 250000
Verarbeite Index 260000
Verarbeite Index 270000
Verarbeite Index 280000
Verarbeite Index 290000
Verarbeite Index 300000
Verarbeite Index 310000
Verarbeite Index 320000
Verarbeite Index 330000
Verarbeite Index 340000


In [58]:
df = pd.read_parquet('/home/data/CC3M_LLaVA/mapper_LLaVA_clean.parquet')
print(len(df))
with open('indices_large_2.txt', 'r') as f:
    indices_to_drop = [int(line.strip()) for line in f]
    print(len(indices_to_drop))
new_df = df.drop(df.index[indices_to_drop])
print(len(new_df))
new_df.to_parquet('/home/data/CC3M_LLaVA/mapper_LLaVA_clean_2.parquet')

472988
2
472986


In [49]:
dataset = ConceptualCaptionsDataset('/home/data/mmssl/CC3m', use_llava_split=True, processor=processor)

In [50]:
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=192, shuffle=False, num_workers=12)

In [52]:
import time

start_time = time.time()
for batch in dataloader:
    pass
end_time = time.time()
print(f"Time: {end_time - start_time}")



KeyboardInterrupt: 

In [None]:
import time

missing_images = []
multiple_images = []
ok_images = 0
search_times = []
idx_count = 0

# start_time = time.time()
for idx, image_name in enumerate(llava_split["id"]):
    image_id = image_name.split('_')[2]
    start_time = time.time()
    image_map = mapper[mapper["key"] == image_id]
    end_time = time.time()
    search_times.append(end_time - start_time)
    #break
    if image_map.shape[0] == 0:
        missing_images.append(image_id)
    elif image_map.shape[0] > 1:
        multiple_images.append(image_id)
    else: ok_images += 1
    idx_count += 1
    # if idx == 300: break

# end_time = time.time()

In [44]:
print(f"elapsed time: {end_time - start_time}")
print(f"average time for lookup: {(end_time - start_time)/300}")

elapsed time: 204.64686965942383
average time for lookup: 0.6821562321980794


In [None]:
print(f"number of lookups: {len(search_times)}")
print(f"average time: {sum(search_times) / len(search_times)}")
search_times

In [55]:
new_mapper = pd.read_parquet('/home/data/CC3M_LLaVA/mapper_LLaVA.parquet')

In [59]:
print(len(llava_split))
print(len(new_mapper))

595375
477329


In [48]:
import time

missing_images = []
multiple_images = []
ok_images = 0
search_times = []

start_time = time.time()
for idx in mapper.index:
    dp = mapper.iloc[idx]
    if idx == 300: break
end_time = time.time()

print(f"elapsed time: {end_time - start_time}")
print(f"average iter duration: {(end_time - start_time) / 300}")

elapsed time: 0.060390472412109375
average iter duration: 0.00020130157470703126
