In [1]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import requests
from typing import List, Union, Dict
import torch
from pathlib import Path
import logging
import gc
import psutil

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def log_memory_usage():
    process = psutil.Process(os.getpid())
    gpu_memory = f"GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, {torch.cuda.memory_reserved()/1024**3:.2f}GB reserved"
    ram_memory = f"RAM Memory: {process.memory_info().rss/1024**3:.2f}GB"
    logger.info(f"{gpu_memory} | {ram_memory}")

class MolmoBatchProcessor:
    def __init__(
        self,
        model_name: str = 'allenai/Molmo-7B-D-0924',
        device: str = None,
        torch_dtype: torch.dtype = None
    ):
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if torch_dtype is None:
            torch_dtype = torch.float16 if device == 'cuda' else torch.float32
            
        self.device = device
        self.torch_dtype = torch_dtype
        
        logger.info(f"Using device: {device}, dtype: {torch_dtype}")
        log_memory_usage()
        
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.set_per_process_memory_fraction(0.95)
        
        with torch.no_grad():
            self.processor = AutoProcessor.from_pretrained(
                model_name,
                trust_remote_code=True
            )
        
        torch.cuda.empty_cache()
        gc.collect()
        
        with torch.no_grad():
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                trust_remote_code=True,
                device_map="auto",
                torch_dtype=torch_dtype,
                max_memory={0: "35GiB"},
                offload_folder="offload",
                offload_state_dict=True
            )
            self.model.eval()
        
        logger.info("Model and processor loaded successfully")
        log_memory_usage()

    def load_image(self, image_source: Union[str, Path, Image.Image]) -> Image.Image:
        if isinstance(image_source, Image.Image):
            return image_source
        elif isinstance(image_source, (str, Path)):
            if str(image_source).startswith(('http://', 'https://')):
                return Image.open(requests.get(image_source, stream=True).raw)
            else:
                return Image.open(image_source)
        else:
            raise ValueError("Unsupported image source type")

    def fix_tensor_dimensions(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Fix tensor dimensions to match model requirements."""
        batch_size = inputs['input_ids'].size(0)
        num_images = inputs['images'].size(1) if len(inputs['images'].shape) > 3 else inputs['images'].size(0)
        
        processed = {}
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                if k == 'images':
                    if len(v.shape) == 3:  # [T, N, D]
                        v = v.unsqueeze(0)  # [B, T, N, D]
                elif k == 'image_input_idx':
                    if len(v.shape) == 2:  # [num_images, num_patches]
                        v = v.unsqueeze(0)  # Add batch dimension
                elif k == 'image_masks':
                    if len(v.shape) == 2:  # [num_images, num_patches]
                        v = v.unsqueeze(0)  # Add batch dimension
                elif len(v.shape) == 1:
                    v = v.unsqueeze(0)
                processed[k] = v.to(self.device, non_blocking=True)
            else:
                processed[k] = v
                
        return processed

    @torch.no_grad()
    def process_single_item(self, image: Image.Image, prompt: str) -> Dict[str, torch.Tensor]:
        """Process a single image-prompt pair with memory cleanup."""
        with torch.amp.autocast('cuda'):
            inputs = self.processor.process(
                images=[image],
                text=prompt
            )
            
            # Fix tensor dimensions and move to device
            inputs = self.fix_tensor_dimensions(inputs)
            
            logger.info(f"Input shapes after processing:")
            for k, v in inputs.items():
                if isinstance(v, torch.Tensor):
                    logger.info(f"{k}: {v.shape}")
                    
            torch.cuda.empty_cache()
            return inputs

    @torch.no_grad()
    def process_batch(
        self,
        image_sources: List[Union[str, Path, Image.Image]],
        prompts: List[str],
        batch_size: int = 1,
        max_new_tokens: int = 200,
        **generation_kwargs
    ) -> List[str]:
        if len(image_sources) != len(prompts):
            raise ValueError("Number of images must match number of prompts")

        results = []
        for i in range(0, len(image_sources), batch_size):
            logger.info(f"\nProcessing batch {i//batch_size + 1}/{(len(image_sources)-1)//batch_size + 1}")
            log_memory_usage()
            
            try:
                batch_images = image_sources[i:i + batch_size]
                batch_prompts = prompts[i:i + batch_size]
                
                with torch.amp.autocast('cuda'):
                    img = self.load_image(batch_images[0])
                    inputs = self.process_single_item(img, batch_prompts[0])
                    del img
                    
                    logger.info("Input shapes before generation:")
                    for k, v in inputs.items():
                        if isinstance(v, torch.Tensor):
                            logger.info(f"{k}: {v.shape}")
                    
                    generation_config = GenerationConfig(
                        max_new_tokens=max_new_tokens,
                        **generation_kwargs
                    )
                    
                    with torch.inference_mode():
                        outputs = self.model.generate_from_batch(
                            inputs,
                            generation_config,
                            tokenizer=self.processor.tokenizer
                        )
                        
                        logger.info(f"Output shape: {outputs.shape if isinstance(outputs, torch.Tensor) else [o.shape for o in outputs]}")

                    if isinstance(outputs, torch.Tensor):
                        generated_tokens = outputs[0, inputs['input_ids'].size(1):].cpu()
                    else:
                        generated_tokens = outputs[0][inputs['input_ids'].size(1):].cpu()
                        
                    generated_text = self.processor.tokenizer.decode(
                        generated_tokens,
                        skip_special_tokens=True
                    )
                    results.append(generated_text)
                    
                    del inputs, outputs, generated_tokens
                    torch.cuda.empty_cache()
                    gc.collect()
                    log_memory_usage()
                
            except Exception as e:
                logger.error(f"Error processing batch: {e}")
                logger.exception("Full traceback:")
                results.extend([None] * len(batch_prompts))
                torch.cuda.empty_cache()
                gc.collect()
                
        return results

if __name__ == "__main__":
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    
    with torch.no_grad():
        processor = MolmoBatchProcessor(
            device='cuda',
            torch_dtype=torch.float16
        )
        
        images = [
            "https://picsum.photos/id/237/536/354",
            "https://picsum.photos/id/238/536/354",
            "https://img.freepik.com/free-photo/view-wild-lion-nature_23-2150460851.jpg",
        ]
        
        prompts = [
            "Describe this image.",
            "What do you see in this image?",
            "Analyze this image in detail.",
        ]
        
        results = processor.process_batch(
            image_sources=images,
            prompts=prompts,
            batch_size=1,
            max_new_tokens=200,
            temperature=0.7,
            do_sample=True
        )
        
        for i, (image, prompt, result) in enumerate(zip(images, prompts, results)):
            print(f"\nBatch item {i+1}:")
            print(f"Image: {image}")
            print(f"Prompt: {prompt}")
            print(f"Generated text: {result}")

  from .autonotebook import tqdm as notebook_tqdm
INFO:__main__:Using device: cuda, dtype: torch.float16
INFO:__main__:GPU Memory: 0.00GB allocated, 0.00GB reserved | RAM Memory: 0.49GB
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:06<00:00,  1.09it/s]
INFO:__main__:Model and processor loaded successfully
INFO:__main__:GPU Memory: 14.94GB allocated, 14.96GB reserved | RAM Memory: 1.58GB
INFO:__main__:
Processing batch 1/3
INFO:__main__:GPU Memory: 14.94GB allocated, 14.96GB reserved | RAM Memory: 1.58GB
INFO:__main__:Input shapes after processing:
INFO:__main__:input_ids: torch.Size([1, 589])
INFO:__main__:images: torch.Size([1, 5, 576, 588])
INFO:__main__:image_input_idx: torch.Size([1, 5, 144])
INFO:__main__:image_masks: torch.Size([1, 5, 576])
INFO:__main__:Input shapes before generation:
INFO:__main__:input_ids: torch.Size([1, 589])
INF


Batch item 1:
Image: https://picsum.photos/id/237/536/354
Prompt: Describe this image.
Generated text:  The image captures a black Labrador puppy sitting on an aged wooden deck. The puppy, looking up towards the camera with large, expressive eyes and a black nose, has floppy ears and a smooth, shiny black coat. Its posture, with front paws tucked under its chin, conveys a sense of curiosity and eagerness. The deck, appearing to be made of natural wood, is weathered and slightly dirty with visible cracks between the planks. 
The lighting in the photograph comes from above, casting subtle shadows that enhance the textures of both the puppy's fur and the wood. The bottom corners of the image are slightly darker, framing the scene and drawing attention to the puppy's adorable face. The overall composition captures a moment of innocent anticipation as the puppy gazes up at the viewer, creating a warm and endearing portrait. 
This detailed view emphasizes the puppy's youthful features and t

In [None]:
okay but its not doing b