In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import requests
from torch.nn import functional as F

# Hugging Face Hub id of the Molmo checkpoint used for both processor and model.
MODEL_ID = 'allenai/Molmo-7B-D-0924'

# Check for GPU availability and set the device.
# NOTE: without CUDA the notebook falls back to CPU, where a model of this size
# in float32 needs tens of GB of RAM and generation is very slow.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    print("CUDA is not available. Please ensure you have a compatible GPU and the necessary drivers installed.")

# Load the processor (Molmo ships its own processing code, hence trust_remote_code).
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)

# Load the model onto `device` in full precision (float32) and switch it to
# inference mode (disables dropout etc.).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float32,
).to(device)
model.eval()

# Define your batch of images and texts
image_urls = [
    "https://picsum.photos/id/237/536/354",
    "https://picsum.photos/id/238/536/354",
    "https://picsum.photos/id/239/536/354"
]

# Seconds to wait for each image download before giving up.
IMAGE_DOWNLOAD_TIMEOUT = 30

# Download, decode and resize every image.
images = []
for url in image_urls:
    # A timeout keeps a stalled server from hanging the kernel, and
    # raise_for_status() stops an HTTP error page from reaching PIL.
    # The `with` block closes the streamed connection when done.
    with requests.get(url, stream=True, timeout=IMAGE_DOWNLOAD_TIMEOUT) as response:
        response.raise_for_status()
        # convert() forces PIL to read the whole stream before the response closes.
        image = Image.open(response.raw).convert('RGB')
    # Resize the image to reduce memory usage.
    # NOTE: this does not preserve the aspect ratio; adjust dimensions as needed.
    image = image.resize((224, 224))
    images.append(image)

# One prompt per image, in the same order as `images`.
texts = [
    "Describe this image 1.",
    "Describe this image 2.",
    "Describe this image 3."
]

# Run the processor on every (text, image) pair on its own, then move each
# tensor it returns onto `device`. Floating-point tensors (float16/float32)
# are upcast to float32 so they match the model weights.
processed_inputs = []
for prompt, pil_image in zip(texts, images):
    sample = processor.process(images=pil_image, text=prompt)
    for name, value in sample.items():
        if not isinstance(value, torch.Tensor):
            continue  # non-tensor entries are left untouched
        moved = value.to(device)
        if moved.dtype in [torch.float16, torch.float32]:
            moved = moved.float()
        sample[name] = moved
        print(f"Input {name} device after to(device): {moved.device}, dtype: {moved.dtype}")
    processed_inputs.append(sample)

# Sanity-check dtypes before batching: index-like tensors must be integers,
# pixel and mask tensors must be float32. Other keys are only printed.
integer_keys = {'input_ids', 'attention_mask', 'position_ids', 'image_input_idx'}
float32_keys = {'images', 'image_masks'}
for sample_idx, sample in enumerate(processed_inputs):
    tensor_items = [(name, t) for name, t in sample.items() if isinstance(t, torch.Tensor)]
    for name, tensor in tensor_items:
        print(f"Processed input {sample_idx}, tensor {name} dtype: {tensor.dtype}")
        if name in integer_keys:
            assert tensor.dtype in (torch.int32, torch.int64), f"Tensor {name} should be integer type but is {tensor.dtype}"
        elif name in float32_keys:
            assert tensor.dtype == torch.float32, f"Tensor {name} should be float32 but is {tensor.dtype}"

# Collate the per-sample dicts into one batch dict, keyed like the first sample.
# Tensors are stacked along a new leading batch dimension; torch.stack requires
# identical shapes across samples. Any non-tensor value is gathered into a list.
batched_inputs = {}
reference_sample = processed_inputs[0]
for key in reference_sample.keys():
    column = [sample[key] for sample in processed_inputs]
    if not isinstance(reference_sample[key], torch.Tensor):
        batched_inputs[key] = column
        continue
    # Every tensor in the column must already live on `device`.
    devices = [t.device for t in column]
    print(f"Devices for {key} before stacking: {devices}")
    assert all(d == device for d in devices), f"Not all tensors for {key} are on the GPU."
    stacked = torch.stack(column, dim=0)
    batched_inputs[key] = stacked
    print(f"Batched input {key} device after stacking: {stacked.device}")
    assert stacked.device == device, f"Batched input {key} is not on the GPU."

# Define generation configuration (shared with the manual loop below).
generation_config = GenerationConfig(max_new_tokens=200)

# ============================================================
# Suggestion 1: Verify Model Output Without Iterative Generation
# ============================================================

# Reference run using Molmo's own batched generation entry point.
with torch.no_grad():
    batch_outputs = model.generate_from_batch(
        batched_inputs,
        generation_config=generation_config,
        tokenizer=processor.tokenizer
    )

# `generate` returns prompt + continuation for every row, and all rows share the
# same (padded) prompt width. The new tokens therefore always start at column
# `prompt_len`. Counting non-pad ids per row would mis-slice whenever a prompt
# is padded or happens to contain the pad token id.
prompt_len = batched_inputs['input_ids'].shape[1]
new_tokens = batch_outputs[:, prompt_len:]

# Decode and print only the generated continuation of each batch item.
print("\nOutputs using generate_from_batch:")
for i, generated_tokens in enumerate(new_tokens):
    generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
    print(f"Generated text for input {i+1}: {generated_text}")

# ============================================================
# Implementing Iterative Generation with Adjustments
# ============================================================

# Initialize variables for iterative generation
input_ids = batched_inputs['input_ids']
batch_size, seq_len = input_ids.shape
print(f"Initial input_ids shape: {input_ids.shape}")

# Prepare attention mask. The processor output has no attention_mask (only
# input_ids / images / image_input_idx / image_masks), so derive one.
# NOTE(review): Molmo's remote generate_from_batch treats -1 as the padding id
# (`input_ids != -1`). Comparing against tokenizer.pad_token_id instead would
# diverge from it and breaks outright if pad_token_id is None.
attention_mask = batched_inputs.get('attention_mask')
if attention_mask is None:
    attention_mask = input_ids != -1
print(f"Initial attention_mask shape: {attention_mask.shape}")

# Position ids mirror generate_from_batch. Every real token, image-patch tokens
# included, gets (number of real tokens up to and including it) - 1, clamped at
# 0 for leading padding. Numbering image tokens differently would make this
# manual loop diverge from the reference run above.
use_position_ids = getattr(model.config, 'use_position_ids', False)
if use_position_ids:
    position_ids = torch.clamp(
        torch.cumsum(attention_mask.to(torch.long), dim=-1) - 1,
        min=0,
    )
    print(f"Initial position_ids shape: {position_ids.shape}")
else:
    position_ids = None

# Extract image-related inputs (None when absent from the batch).
images = batched_inputs.get('images')
image_masks = batched_inputs.get('image_masks')
image_input_idx = batched_inputs.get('image_input_idx')
append_last_valid_logits = batched_inputs.get('append_last_valid_logits')

# Running output: starts as the prompt; sampled tokens are appended column-wise.
generated_sequences = input_ids.clone()

# Per-row flag, set once that row has produced the EOS token.
done = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)

# Same token budget as the generate_from_batch reference run.
max_new_tokens = generation_config.max_new_tokens

# Sampling hyperparameters. Logits are divided by the temperature, and top_p is
# the nucleus mass to keep.
temperature = 1.0
top_p = 0.9
if temperature <= 0:
    raise ValueError(f"temperature must be > 0, got {temperature}")
if not 0.0 < top_p <= 1.0:
    raise ValueError(f"top_p must be in (0, 1], got {top_p}")

# Seed torch's RNG so the multinomial sampling below is reproducible
# across Restart & Run All.
SAMPLING_SEED = 0
torch.manual_seed(SAMPLING_SEED)

# KV cache. None means the next forward pass processes the full prompt.
past_key_values = None

# EOS id used to stop rows, and the filler id appended to rows that already
# finished. Both are loop-invariant, so compute them once. Use `is None` rather
# than `or` so that a legitimate token id of 0 is respected.
eos_token_id = processor.tokenizer.eos_token_id
if eos_token_id is None:
    eos_token_id = processor.tokenizer.encode("</s>")[0]
fill_token_id = processor.tokenizer.pad_token_id
if fill_token_id is None:
    fill_token_id = eos_token_id

# Set to True to print shapes/position ids on every step instead of only step 0.
log_every_step = False

with torch.no_grad():
    for step in range(max_new_tokens):
        verbose = log_every_step or step == 0
        if verbose:
            print(f"\n=== Generation Step {step} ===")

        if past_key_values is None:
            # First step: full prompt plus image inputs; this fills the KV cache.
            model_inputs = {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'position_ids': position_ids,
                'images': images,
                'image_masks': image_masks,
                'image_input_idx': image_input_idx,
                'append_last_valid_logits': append_last_valid_logits,
                'use_cache': True,
            }
        else:
            # Later steps: feed only the token sampled on the previous step.
            model_inputs = {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'past_key_values': past_key_values,
                'use_cache': True,
            }
            if position_ids is not None:
                # position_ids was already extended by one at the end of the
                # previous step, so its last column IS the new token's
                # position. Adding 1 again here would skip a position.
                model_inputs['position_ids'] = position_ids[:, -1:]

        if verbose:
            print(f"Step {step}, input_ids shape: {model_inputs['input_ids'].shape}")
            print(f"Step {step}, attention_mask shape: {model_inputs['attention_mask'].shape}")
            if model_inputs.get('position_ids') is not None:
                print(f"Step {step}, position_ids shape: {model_inputs['position_ids'].shape}")
                print(f"Step {step}, position_ids: {model_inputs['position_ids']}")
            if images is not None and past_key_values is None:
                print(f"Step {step}, images shape: {images.shape}")
                print(f"Step {step}, image_masks shape: {image_masks.shape}")
                print(f"Step {step}, image_input_idx shape: {image_input_idx.shape}")

        # Forward pass; only the logits of the last position matter for sampling.
        step_outputs = model(**model_inputs)
        next_token_logits = step_outputs.logits[:, -1, :]

        # Abort on NaN/Inf logits rather than sampling from a broken distribution.
        if not torch.isfinite(next_token_logits).all():
            print(f"Logits contain NaN or Inf values at step {step}")
            break

        # Apply temperature
        next_token_logits = next_token_logits / temperature

        # Top-p (nucleus) filtering: keep the smallest set of most-likely tokens
        # whose mass reaches top_p, INCLUDING the token that crosses the
        # threshold. A token is removed only if the mass of all tokens ranked
        # above it already exceeds top_p (i.e. cumulative probs shifted right by one).
        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = torch.zeros_like(cumulative_probs, dtype=torch.bool)
        sorted_indices_to_remove[:, 1:] = cumulative_probs[:, :-1] > top_p

        # Ensure at least min_tokens_to_keep tokens are kept
        min_tokens_to_keep = 1
        sorted_indices_to_remove[:, :min_tokens_to_keep] = False

        # Map the removal mask from sorted order back to vocabulary order.
        indices_to_remove = torch.zeros_like(next_token_logits, dtype=torch.bool)
        indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
        next_token_logits = next_token_logits.masked_fill(indices_to_remove, -float('inf'))

        # Sample next token
        next_tokens = torch.multinomial(torch.softmax(next_token_logits, dim=-1), num_samples=1).squeeze(-1)
        # Rows that already produced EOS only receive the filler token, so
        # nothing meaningful is appended after their answer ends.
        next_tokens = torch.where(done, torch.full_like(next_tokens, fill_token_id), next_tokens)

        # Append next tokens and advance the cached state by one position.
        generated_sequences = torch.cat([generated_sequences, next_tokens.unsqueeze(-1)], dim=1)
        past_key_values = step_outputs.past_key_values
        input_ids = next_tokens.unsqueeze(-1)

        new_attention_mask = torch.ones((batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device)
        attention_mask = torch.cat([attention_mask, new_attention_mask], dim=1)

        if position_ids is not None:
            # The new token sits one position after the previous last token.
            position_ids = torch.cat([position_ids, position_ids[:, -1:] + 1], dim=1)

        # Check for EOS token
        done |= next_tokens == eos_token_id
        if done.all():
            print(f"All sequences are done at step {step}")
            break

    # Decode only the newly generated columns (everything after the prompt),
    # the same way the generate_from_batch outputs are decoded.
    generated_texts = [
        processor.tokenizer.decode(sequence, skip_special_tokens=True)
        for sequence in generated_sequences[:, seq_len:]
    ]

    # Print generated texts
    for idx, text in enumerate(generated_texts):
        print(f"Generated text for input {idx + 1}: {text}")


  from .autonotebook import tqdm as notebook_tqdm


Using GPU: NVIDIA L40S


Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.33it/s]


Input input_ids device after to(device): cuda:0, dtype: torch.int64
Input images device after to(device): cuda:0, dtype: torch.float32
Input image_input_idx device after to(device): cuda:0, dtype: torch.int32
Input image_masks device after to(device): cuda:0, dtype: torch.float32
Input input_ids device after to(device): cuda:0, dtype: torch.int64
Input images device after to(device): cuda:0, dtype: torch.float32
Input image_input_idx device after to(device): cuda:0, dtype: torch.int32
Input image_masks device after to(device): cuda:0, dtype: torch.float32
Input input_ids device after to(device): cuda:0, dtype: torch.int64
Input images device after to(device): cuda:0, dtype: torch.float32
Input image_input_idx device after to(device): cuda:0, dtype: torch.int32
Input image_masks device after to(device): cuda:0, dtype: torch.float32
Processed input 0, tensor input_ids dtype: torch.int64
Processed input 0, tensor images dtype: torch.float32
Processed input 0, tensor image_input_idx dtype: