In [None]:
pip install datasets

# Imports

In [None]:
import os
import urllib
import io
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from PIL import Image as pimage
from transformers import CLIPModel, CLIPProcessor
from datasets import load_dataset
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

# Configuration

In [None]:
# --- Runtime configuration ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU when torch can see one
BATCH_SIZE = 1024  # number of samples the DataLoader groups into one batch
LAION_SPLIT = "train[:1%]"  # Adjust to load a subset (1%)
# Single source of truth for the checkpoint so the model and its processor
# can never be loaded from two different identifiers by accident.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

# --- Initialize CLIP Model ---
print("Loading CLIP model...")
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT).to(DEVICE)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

# Helper Functions

In [None]:
def download_image(url, timeout=5):
    """
    Downloads an image from a URL with a custom User-Agent.

    This is a best-effort helper: any failure (bad URL, network error,
    timeout, ...) is reported on stdout and signalled by returning None
    instead of raising.

    Args:
        url: Image URL.
        timeout: Max time in seconds for download.

    Returns:
        img_stream: BytesIO stream holding the raw response bytes, or None
        if the download failed.
    """
    # `import urllib` alone does not guarantee that the `urllib.request`
    # submodule is loaded; import it explicitly so this function does not
    # depend on some other library having imported it first.
    import urllib.request

    try:
        request = urllib.request.Request(
            url,
            data=None,
            # Send a browser-like User-Agent header with the request.
            headers={"User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0"},
        )
        with urllib.request.urlopen(request, timeout=timeout) as response:
            # Read the whole body into memory so the connection can be closed.
            return io.BytesIO(response.read())
    except Exception as e:
        print(f"Failed to download image from {url}: {e}")
        return None

def get_text_emb(text):
    """
    Generates normalized text embeddings using CLIP.

    Uses the module-level `processor`, `model` and `DEVICE`.

    Args:
        text: Input text (string).

    Returns:
        Normalized (unit L2 norm) text embedding as a 1-D float32 NumPy
        array, or None if tokenization or inference fails (the error is
        printed).
    """
    try:
        with torch.no_grad():
            # truncation=True clips over-long captions to the tokenizer's
            # maximum length; without it a long caption makes the text
            # encoder raise and the whole sample is lost.
            inputs = processor(
                text=[text],
                return_tensors="pt",
                padding=True,
                truncation=True,
            ).to(DEVICE)
            text_emb = model.get_text_features(**inputs)
            # Scale to unit length along the feature dimension.
            text_emb /= text_emb.norm(dim=-1, keepdim=True)
            # A single text was passed in, so take row 0 of the batch.
            text_emb = text_emb.cpu().detach().numpy().astype("float32")[0]
        return text_emb
    except Exception as e:
        print(f"Error generating text embedding: {e}")
        return None

def get_image_emb(image_url):
    """
    Generates normalized image embeddings using CLIP.

    Uses the module-level `processor`, `model` and `DEVICE`.

    Args:
        image_url: URL of the image.

    Returns:
        Normalized (unit L2 norm) image embedding as a 1-D float32 NumPy
        array, or None if the image could not be downloaded, decoded or
        embedded.
    """
    img_stream = download_image(image_url)
    if img_stream is None:
        # download_image has already reported the failure; just signal it.
        return None

    try:
        # Force 3-channel RGB regardless of the source image mode.
        image = pimage.open(img_stream).convert("RGB")
        with torch.no_grad():
            inputs = processor(images=image, return_tensors="pt").to(DEVICE)
            image_emb = model.get_image_features(**inputs)
            # Scale to unit length along the feature dimension.
            image_emb /= image_emb.norm(dim=-1, keepdim=True)
            # A single image was passed in, so take row 0 of the batch.
            image_emb = image_emb.cpu().detach().numpy().astype("float32")[0]
        return image_emb
    except Exception as e:
        # Report decode/inference failures instead of dropping them silently,
        # matching the error handling in get_text_emb.
        print(f"Error generating image embedding for {image_url}: {e}")
        return None

# --- Dataset Class ---
class LAIONDataset(Dataset):
    """
    Thin map-style wrapper that exposes each record of the wrapped dataset
    as an (image_url, caption) pair.
    """

    def __init__(self, dataset):
        # Any indexable, sized collection of records with "image_url" and
        # "caption" keys works here.
        self.dataset = dataset

    def __len__(self):
        # Delegate to the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        # Only the URL and the caption are forwarded; other fields are ignored.
        return record["image_url"], record["caption"]

# Load LAION (currently the `conceptual_captions` dataset)

In [None]:
print("Loading LAION dataset...")

# Load only the slice selected by LAION_SPLIT.
laion_dataset = load_dataset(
    "conceptual_captions",
    split=LAION_SPLIT,
)

# Wrap the records so each item is an (image_url, caption) pair.
laion_train_dataset = LAIONDataset(dataset=laion_dataset)

# shuffle=True: batches are drawn in a random order.
laion_dataloader = DataLoader(
    dataset=laion_train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Extract Embeddings

In [None]:
# Accumulators for the embeddings of every successfully processed pair.
image_embeddings = []
text_embeddings = []

def process_single_example(image_url, caption):
    """
    Process a single image and caption pair to extract embeddings.

    The result is all-or-nothing: if either the image embedding or the
    text embedding cannot be produced, neither is returned, so callers
    never receive a half-filled pair.

    Args:
        image_url: URL of the image to embed.
        caption: Caption text to embed.

    Returns:
        (image_emb, text_emb) tuple, or (None, None) if either embedding
        fails.
    """
    try:
        image_emb = get_image_emb(image_url)
        if image_emb is None:
            # Skip the text model entirely when the image is unusable.
            return None, None

        text_emb = get_text_emb(caption)
        if text_emb is None:
            return None, None

        return image_emb, text_emb
    except Exception as e:
        print(f"Error processing {image_url}: {e}")
        return None, None

print("Extracting embeddings...")
# One pool of worker threads is shared by all batches.
with ThreadPoolExecutor(max_workers=8) as executor:  # Adjust max_workers based on CPU cores
    batch_iter = tqdm(laion_dataloader, desc="Processing Batches")
    for batch_idx, (image_urls, captions) in enumerate(batch_iter):
        # Fan out: one task per (url, caption) pair in this batch.
        futures = [
            executor.submit(process_single_example, url, text)
            for url, text in zip(image_urls, captions)
        ]

        # Gather results in completion order, keeping only complete pairs.
        batch_image_embeds = []
        batch_text_embeds = []
        completed = tqdm(
            as_completed(futures),
            total=len(futures),
            desc=f"Batch {batch_idx + 1}",
            leave=False,
        )
        for finished in completed:
            img_vec, txt_vec = finished.result()
            if img_vec is None or txt_vec is None:
                continue
            batch_image_embeds.append(img_vec)
            batch_text_embeds.append(txt_vec)

        # Fold this batch's valid embeddings into the global accumulators.
        image_embeddings += batch_image_embeds
        text_embeddings += batch_text_embeds

        print(f"Batch {batch_idx + 1}: Processed {len(batch_image_embeds)} images and {len(batch_text_embeds)} captions")

# Turn the accumulated lists into NumPy arrays (same names, new type).
image_embeddings = np.array(image_embeddings)
text_embeddings = np.array(text_embeddings)

# Optional: Save embeddings to disk
# np.save("image_embeddings.npy", image_embeddings)
# np.save("text_embeddings.npy", text_embeddings)

print(f"Extracted {len(image_embeddings)} image embeddings and {len(text_embeddings)} text embeddings.")
