In [14]:
import argparse
import io
import itertools
import json
import logging
import os
import sys
from typing import Optional, Dict, Any

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    ShardingStrategy
)



class StreamingCC12MDataset(torch.utils.data.IterableDataset):
    """Stream (visual embedding, SONAR text embedding) pairs from an image/caption dataset.

    Each sample of ``dataset`` must provide an image under ``'jpg'`` and a caption
    under ``'txt'``. The image is embedded with ``get_image_features`` of a
    SigLIP-style model; the caption is embedded with a SONAR text pipeline's
    ``predict``. Samples that fail to process are logged and skipped.
    """

    def __init__(
        self,
        dataset,
        siglip_processor,
        sonar_text_encoder,
        max_samples=None,
        vision_model=None,
    ):
        """
        Streaming dataset wrapper

        Args:
            dataset: Iterable of samples (e.g. a Hugging Face streaming dataset);
                each sample is a mapping with ``'jpg'`` and ``'txt'`` entries,
                either raw bytes or already-decoded image / ``str``.
            siglip_processor: SigLIP image processor, called as
                ``processor(images=[image], return_tensors="pt")``.
            sonar_text_encoder: SONAR text embedding pipeline exposing
                ``predict(texts, source_lang=...)``.
            max_samples: Optional cap on the number of successfully processed
                samples per pass over the data, shared across DataLoader
                workers. ``None`` means no limit; ``0`` yields nothing.
            vision_model: Model exposing ``get_image_features``. When ``None``,
                a module-level ``model`` is used (the original behaviour).
        """
        self.dataset = dataset
        self.siglip_processor = siglip_processor
        self.sonar_text_encoder = sonar_text_encoder
        self.max_samples = max_samples
        self.vision_model = vision_model

    def __iter__(self):
        """
        Yield ``(visual_embedding, concept_embedding)`` pairs as CPU tensors.

        A fresh iterator over ``self.dataset`` is created on every call, so the
        dataset can be iterated for several epochs. Inside a multi-worker
        DataLoader, samples are split round-robin between workers so none is
        yielded twice per pass.
        """
        # Resolve the encoder up front so a missing model fails loudly here
        # instead of being swallowed per item by the except clause below.
        encoder = self.vision_model if self.vision_model is not None else model
        device = getattr(encoder, "device", None)

        items, limit = self._worker_items()

        sample_count = 0
        for item in items:
            # `is not None` so that max_samples=0 means "yield nothing".
            if limit is not None and sample_count >= limit:
                break

            try:
                image = self._load_image(item['jpg'])

                # SIG-LIP image processing
                image_inputs = self.siglip_processor(
                    images=[image],
                    return_tensors="pt"
                )
                # Keep processor outputs on the encoder's device (e.g. a model
                # loaded with device_map="auto"); plain dicts are left as-is.
                if device is not None and hasattr(image_inputs, "to"):
                    image_inputs = image_inputs.to(device)

                with torch.no_grad():
                    visual_embedding = encoder.get_image_features(**image_inputs)

                text = self._load_text(item['txt'])

                # SONAR text embedding
                concept_embedding = self.sonar_text_encoder.predict(
                    [text],
                    source_lang="eng_Latn"
                )

                # Return CPU tensors: DataLoader(pin_memory=True) can only pin
                # dense CPU tensors. as_tensor avoids copying tensor inputs.
                pair = (
                    visual_embedding.squeeze().cpu(),
                    torch.as_tensor(concept_embedding).squeeze().cpu(),
                )
            except Exception as e:
                # Best-effort streaming: skip corrupt/unprocessable samples.
                logging.error(f"Error processing item: {e}")
                continue

            # Yield outside the try so consumer-side exceptions are not
            # misreported as processing errors.
            yield pair
            sample_count += 1

    def _worker_items(self):
        """Return this worker's share of the stream and its per-worker sample cap."""
        items = iter(self.dataset)
        limit = self.max_samples
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None or worker_info.num_workers <= 1:
            return items, limit

        n_workers, worker_id = worker_info.num_workers, worker_info.id
        # Round-robin split: each worker still reads the full stream but keeps
        # only every n-th sample, so samples are not duplicated across workers.
        items = itertools.islice(items, worker_id, None, n_workers)
        if limit is not None:
            # Split the global cap so the per-worker caps sum to max_samples.
            limit = limit // n_workers + (1 if worker_id < limit % n_workers else 0)
        return items, limit

    @staticmethod
    def _load_image(raw):
        """Return an RGB image from encoded bytes or an already-decoded image object."""
        if isinstance(raw, (bytes, bytearray)):
            # NOTE(review): `Image` is presumably PIL.Image; it is not imported in
            # the visible cells — make sure the notebook imports it.
            raw = Image.open(io.BytesIO(raw))
        return raw.convert('RGB')

    @staticmethod
    def _load_text(raw):
        """Return the caption as text, decoding UTF-8 bytes when necessary."""
        if isinstance(raw, (bytes, bytearray)):
            return raw.decode('utf-8')
        return raw

def create_streaming_dataloader(
    dataset,
    siglip_processor,
    sonar_text_encoder,
    config
):
    """
    Create a DataLoader over a StreamingCC12MDataset.

    Args:
        dataset: Streaming Hugging Face dataset (or any iterable of samples).
        siglip_processor: SIG-LIP image processor.
        sonar_text_encoder: SONAR text embedding model.
        config: Mapping of training options (``None`` is treated as empty).
            Keys read here: ``max_samples`` (default ``None``), ``batch_size``
            (default 64), ``num_workers`` (default 4), ``pin_memory``
            (default ``True``).

    Returns:
        torch.utils.data.DataLoader yielding batches of
        ``(visual_embedding, concept_embedding)``.
    """
    config = config or {}

    # Create streaming dataset wrapper
    streaming_dataset = StreamingCC12MDataset(
        dataset,
        siglip_processor=siglip_processor,
        sonar_text_encoder=sonar_text_encoder,
        max_samples=config.get('max_samples')
    )

    # NOTE(review): the dataset runs the image encoder inside each worker. If
    # that model lives on CUDA, forked workers cannot use it — consider
    # num_workers=0 or a 'spawn' multiprocessing context.
    dataloader = torch.utils.data.DataLoader(
        streaming_dataset,
        batch_size=config.get('batch_size', 64),
        num_workers=config.get('num_workers', 4),
        # Configurable; pinning only applies to CPU tensors.
        pin_memory=config.get('pin_memory', True)
    )

    return dataloader


import os

from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
import transformers

# SONAR text encoder: turns captions into sentence embeddings.
t2vec_model = TextToEmbeddingModelPipeline(
    encoder="text_sonar_basic_encoder",
    tokenizer="text_sonar_basic_encoder"
)

# SigLIP 2 image encoder and the processor from the same checkpoint
# (the checkpoint name indicates 384px inputs).
ckpt = "google/siglip2-so400m-patch14-384"
siglip_model = transformers.AutoModel.from_pretrained(
    ckpt, device_map="auto"
).eval()
processor = transformers.AutoProcessor.from_pretrained(ckpt)

# StreamingCC12MDataset uses a global `model` as its image encoder unless told
# otherwise. Bind it to the checkpoint that matches `processor`, so a `model`
# from another cell (e.g. the 224px base checkpoint) is not silently used.
model = siglip_model

# Cache location can be overridden via the environment; the default keeps the
# original path.
CC12M_CACHE_DIR = os.environ.get(
    "CC12M_CACHE_DIR", '/gpfs/gibbs/project/hartley/tjb76/Sig-CIP/Datasets'
)

# NOTE(review): `load_dataset` (Hugging Face `datasets`) and `training_config`
# are not defined in this cell; an earlier cell must provide them.
dataset = load_dataset(
    "pixparse/cc12m-wds",
    cache_dir=CC12M_CACHE_DIR,
    streaming=True,
    split='train'
)

train_loader = create_streaming_dataloader(
    dataset,
    siglip_processor=processor,
    sonar_text_encoder=t2vec_model,
    config=training_config
)

ModuleNotFoundError: No module named 'sonar'

In [2]:
import transformers

# Load the SigLIP 2 so400m encoder plus its processor from one checkpoint.
ckpt = "google/siglip2-so400m-patch14-384"
siglip_model = transformers.AutoModel.from_pretrained(ckpt, device_map="auto")
siglip_model.eval()  # nn.Module.eval() returns self; switch to inference mode
processor = transformers.AutoProcessor.from_pretrained(ckpt)

2025-04-08 13:34:58.053572: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-08 13:34:58.067072: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744133698.082177 3122860 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744133698.086667 3122860 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-08 13:34:58.102425: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

TypeError: expected str, bytes or os.PathLike object, not NoneType

In [3]:
from transformers import pipeline

# Zero-shot image classification backed by the SigLIP 2 so400m checkpoint.
ckpt = "google/siglip2-so400m-patch14-384"
image_classifier = pipeline(
    task="zero-shot-image-classification",
    model=ckpt,
)


tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Device set to use cuda:0


In [8]:
from transformers import SiglipVisionModel, SiglipImageProcessor, SiglipModel, AutoImageProcessor, AutoProcessor

# One constant for the checkpoint so the three loads below cannot drift apart.
SIGLIP_BASE_CKPT = "google/siglip2-base-patch16-224"

# Load only the vision components
vision_model = SiglipVisionModel.from_pretrained(SIGLIP_BASE_CKPT, device_map="auto").eval()
image_processor = SiglipImageProcessor.from_pretrained(SIGLIP_BASE_CKPT)

# Full image+text model (no device_map, so presumably loaded on CPU).
# NOTE(review): this duplicates the vision tower already held by `vision_model`.
# The global name `model` may also be picked up as the image encoder by
# StreamingCC12MDataset. Rebinding it to this 224px checkpoint can mismatch a
# 384px processor created in another cell.
model = SiglipModel.from_pretrained(SIGLIP_BASE_CKPT)
