# In [None]: (cell marker left over from a Jupyter notebook export)
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
import warnings
from collections import defaultdict
from PIL import Image
import requests
import gc
import psutil
import os
from contextlib import contextmanager
from contextlib import nullcontext

# Silence every warning category (including library deprecation notices) so
# the demo's log output stays readable.
warnings.filterwarnings('ignore')

# Choose the compute device once at import time. main() reads this
# module-level `device`; the other functions receive it as an argument.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"\nUsing device: {device}")

if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name()}")
    # Allow TF32 math for matmuls and cuDNN ops; it only changes anything on
    # GPUs with TF32 support (Ampere or newer).
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

def resize_and_pad_image(image, target_size=(224, 224)):
    """Resize *image* to exactly *target_size* using Lanczos resampling.

    Despite the name, no padding is applied: the image is stretched to the
    target size, so its aspect ratio is NOT preserved.

    Args:
        image: a PIL image.
        target_size: output ``(width, height)`` in pixels. The default 224x224
            is a common input size for vision models.

    Returns:
        A new PIL image of size ``target_size``.
    """
    return image.resize(tuple(target_size), Image.Resampling.LANCZOS)

def debug_processor_output(processor, image_url="https://picsum.photos/id/237/536/354", timeout=30):
    """Run *processor* on one downloaded sample image and print each tensor's shape.

    Args:
        processor: object exposing ``process(images=..., text=..., return_tensors="pt")``
            (the Molmo processor loaded in ``main``).
        image_url: image to download; defaults to the sample used by ``main``.
        timeout: seconds to wait on the HTTP request before giving up.

    Returns:
        The raw mapping returned by ``processor.process``.

    Raises:
        requests.HTTPError: if the image download returns an error status.
    """
    # Image.open() is lazy, so decode the pixels before the streamed response
    # (and its underlying connection) is closed by the ``with`` block.
    with requests.get(image_url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        image = Image.open(response.raw)
        image.load()
    image = resize_and_pad_image(image)

    print("\nRunning processor debug...")
    raw_inputs = processor.process(
        images=[image],
        text="Test image.",
        return_tensors="pt",
    )

    print("\nRaw processor outputs:")
    for key, value in raw_inputs.items():
        if isinstance(value, torch.Tensor):
            print(f"{key}: {value.shape}")
    return raw_inputs

def get_memory_usage():
    """Summarise this process's memory use as a single human-readable line.

    Reports the resident set size of the current process and the bytes
    currently allocated by PyTorch on the CUDA device (0 without CUDA),
    both in megabytes with two decimals.
    """
    bytes_per_mb = 1024 * 1024
    rss_mb = psutil.Process(os.getpid()).memory_info().rss / bytes_per_mb
    gpu_mb = 0
    if torch.cuda.is_available():
        gpu_mb = torch.cuda.memory_allocated() / bytes_per_mb
    return f"CPU Memory: {rss_mb:.2f}MB, GPU Memory: {gpu_mb:.2f}MB"

@contextmanager
def batch_memory_manager():
    """Free memory after the wrapped block, whether or not it raised.

    On exit the garbage collector runs first, so tensors kept alive only by
    reference cycles are released. ``torch.cuda.empty_cache()`` then returns
    the now-unused cached CUDA blocks. Running them in the opposite order
    would leave those blocks cached. Finally the current memory usage is
    printed.
    """
    try:
        yield
    finally:
        gc.collect()
        torch.cuda.empty_cache()
        print(f"Memory after batch cleanup: {get_memory_usage()}")

def print_model_config(model):
    """Dump diagnostic information about *model* to stdout.

    Prints, in order:
    - the public (non-underscore) attributes of ``model.config``;
    - total and trainable parameter counts;
    - the public attribute names containing "generate", with the positional
      argument names of any that expose ``__code__``;
    - the vocabulary size and ``max_position_embeddings`` (or "Not specified")
      read from ``model.config``.
    """
    cfg = model.config

    print("\nModel Configuration:")
    print("\nConfig attributes:")
    for attr_name, attr_value in vars(cfg).items():
        if attr_name.startswith('_'):
            continue
        print(f"{attr_name}: {attr_value}")

    print("\nModel Structure:")
    params = list(model.parameters())
    total_params = sum(p.numel() for p in params)
    trainable_params = sum(p.numel() for p in params if p.requires_grad)
    print(f"Number of parameters: {total_params:,}")
    print(f"Number of trainable parameters: {trainable_params:,}")

    print("\nModel Methods:")
    gen_names = [name for name in dir(model) if not name.startswith('_') and 'generate' in name]
    print("Available generation methods:", gen_names)

    for name in gen_names:
        if not hasattr(model, name):
            continue
        print(f"\n{name} method signature:")
        code = getattr(getattr(model, name), '__code__', None)
        if code is not None:
            print(f"Arguments: {code.co_varnames[:code.co_argcount]}")

    print("\nTokenizer Information:")
    print(f"Vocabulary size: {cfg.vocab_size}")
    max_len = cfg.max_position_embeddings if hasattr(cfg, 'max_position_embeddings') else 'Not specified'
    print(f"Model max length: {max_len}")

def process_single_thread(thread, processor, device, timeout=30):
    """Download one thread's image and build model-ready inputs for it on *device*.

    Args:
        thread: mapping with an "image_url" to download and the prompt "text".
        processor: object exposing ``process(images=..., text=..., return_tensors="pt")``
            (the Molmo processor loaded in ``main``).
        device: ``torch.device`` every tensor input is moved to.
        timeout: seconds to wait for the image download.

    Returns:
        dict of the processor outputs with these changes:
        - a leading batch dimension of 1 is added where the processor omitted
          one ("input_ids" when 1-D, "images" when fewer than 4-D,
          "image_masks"/"image_input_idx" when 2-D);
        - tensors are moved to *device*; non-tensor values pass through;
        - an all-ones "attention_mask" shaped like "input_ids" is added
          (float16 on CUDA, float32 otherwise).

    Raises:
        requests.HTTPError: if the image download returns an error status.
    """
    # Image.open() is lazy, so decode the pixels before the streamed response
    # (and its underlying connection) is closed by the ``with`` block.
    with requests.get(thread["image_url"], stream=True, timeout=timeout) as response:
        response.raise_for_status()
        image = Image.open(response.raw)
        image.load()

    # The processor's tensors are moved to *device* below (presumably they
    # start on CPU). CUDA autocast only affects CUDA ops, so no autocast
    # context is used around preprocessing.
    inputs = processor.process(
        images=[image],
        text=thread["text"],
        return_tensors="pt",
    )

    processed_inputs = {}
    for key, value in inputs.items():
        if not isinstance(value, torch.Tensor):
            processed_inputs[key] = value
            continue
        if key == "input_ids" and value.dim() == 1:
            value = value.unsqueeze(0)
        elif key == "images" and value.dim() < 4:
            value = value.unsqueeze(0)
        elif key in ("image_masks", "image_input_idx") and value.dim() == 2:
            # Prepend the batch dim. Reshaping to (1, 2, -1) would scramble
            # the per-crop rows whenever the processor emits other than 2 crops.
            value = value.unsqueeze(0)
        processed_inputs[key] = value.to(device)

    # NOTE(review): a float mask dtype is kept as before; confirm the model's
    # remote code accepts a float16 mask (many HF models expect long/bool).
    processed_inputs["attention_mask"] = torch.ones_like(
        processed_inputs["input_ids"],
        dtype=torch.float16 if device.type == "cuda" else torch.float32,
    )

    return processed_inputs

def _log_model_inputs(model_inputs):
    """Print shape/dtype/device of tensor inputs and layer count + first-layer shapes of list inputs."""
    for name, value in model_inputs.items():
        if isinstance(value, torch.Tensor):
            print(f"{name}: shape={value.shape}, dtype={value.dtype}, device={value.device}")
        elif isinstance(value, list):
            print(f"{name}: {len(value)} layers")
            print(f"First layer shapes - Key: {value[0][0].shape}, Value: {value[0][1].shape}")


def _stack_past_key_values(sequences, num_layers):
    """Concatenate each layer's cached (key, value) pair across *sequences* along dim 0."""
    stacked = []
    for layer_idx in range(num_layers):
        print(f"\nProcessing layer {layer_idx}")
        layer_keys = []
        layer_values = []
        for seq_idx, seq in enumerate(sequences):
            past_key, past_value = seq["past_key_values"][layer_idx]
            print(f"Sequence {seq_idx} - Key shape: {past_key.shape}, Value shape: {past_value.shape}")
            layer_keys.append(past_key)
            layer_values.append(past_value)
        keys = torch.cat(layer_keys, dim=0)
        values = torch.cat(layer_values, dim=0)
        print(f"Concatenated - Key shape: {keys.shape}, Value shape: {values.shape}")
        stacked.append((keys, values))
    return stacked


def process_subsequent_sequences(subsequent_sequences, model, processor, num_layers, device):
    """Run one greedy decoding step for sequences that already have a KV cache.

    Sequences are grouped by their cached length (``past_key_values[0][0].shape[2]``)
    so each group can be batched by plain concatenation. For every group, the
    last token of each sequence's ``generated_ids`` is fed with the stacked
    cache, and the arg-max of the final-position logits is taken.

    Args:
        subsequent_sequences: sequence dicts with "generated_ids" and
            "past_key_values" (a per-layer list of ``(key, value)`` tensors).
        model: the causal LM; called with ``input_ids``, ``attention_mask``,
            ``position_ids``, ``past_key_values`` and ``use_cache``.
        processor: supplies ``tokenizer.pad_token_id``/``eos_token_id``.
        num_layers: number of cache layers to stack.
        device: device for the attention mask and position ids.

    Returns:
        One ``(sequences, result)`` pair per group. ``result`` exposes
        ``logits`` and ``past_key_values`` as returned by the model, plus
        ``generated_tokens`` with shape ``(group_size, 1)``.
    """
    import types  # SimpleNamespace: a plain attribute bag for each group's result

    print("\n" + "="*80)
    print("ENTERING PROCESS_SUBSEQUENT_SEQUENCES")
    print(f"Number of sequences to process: {len(subsequent_sequences)}")
    print(f"Number of layers: {num_layers}")
    print(f"Device: {device}")
    print("="*80)

    mask_dtype = torch.float16 if device.type == "cuda" else torch.float32
    results = []

    with torch.no_grad():
        generation_config = GenerationConfig(
            max_new_tokens=1,
            do_sample=False,
            num_beams=1,
            use_cache=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
        )
        print("\nGeneration Config:")
        print(f"max_new_tokens: {generation_config.max_new_tokens}")
        print(f"pad_token_id: {generation_config.pad_token_id}")
        print(f"eos_token_id: {generation_config.eos_token_id}")

        # Only caches of identical length can be concatenated along the batch dim.
        print("\nGrouping sequences by sequence length...")
        seq_len_to_sequences = defaultdict(list)
        for idx, seq in enumerate(subsequent_sequences):
            seq_len = seq["past_key_values"][0][0].shape[2]
            seq_len_to_sequences[seq_len].append(seq)
            print(f"Sequence {idx}: length = {seq_len}, generated_ids shape = {seq['generated_ids'].shape}")

        print(f"\nFound {len(seq_len_to_sequences)} different sequence lengths: {list(seq_len_to_sequences.keys())}")

        for seq_len, sequences in seq_len_to_sequences.items():
            print("\n" + "-"*80)
            print(f"Processing sequence group with seq_len = {seq_len}")
            print(f"Number of sequences in this group: {len(sequences)}")

            batch_size = len(sequences)
            print("\nPreparing batch inputs...")
            print("Collecting last tokens from each sequence...")
            batch_input_ids = torch.cat([seq["generated_ids"][..., -1:] for seq in sequences], dim=0)
            print(f"batch_input_ids shape: {batch_input_ids.shape}")
            print(f"batch_input_ids values: {batch_input_ids.tolist()}")

            # Cached positions plus the one token fed in this step.
            mask_len = seq_len + generation_config.max_new_tokens
            print(f"\nCalculating attention mask length: {seq_len} (current) + {generation_config.max_new_tokens} (new) = {mask_len}")

            # NOTE(review): the mask is all ones, so any padded positions inside
            # the caches are attended to; confirm callers never cache padding.
            print("Creating attention mask...")
            batch_attention_mask = torch.ones((batch_size, mask_len), dtype=mask_dtype, device=device)
            print(f"batch_attention_mask shape: {batch_attention_mask.shape}")

            print("\nStacking past key values...")
            batch_past_key_values = _stack_past_key_values(sequences, num_layers)

            # Configs without the flag simply let the model derive positions itself.
            position_ids = None
            if getattr(model.config, "use_position_ids", False):
                print("\nModel uses position IDs. Creating position tensor...")
                position_ids = torch.full((batch_size, 1), seq_len, dtype=torch.long, device=device)
                print(f"position_ids shape: {position_ids.shape}")
                print(f"position_ids values: {position_ids.tolist()}")

            print("\nPreparing final model inputs...")
            model_inputs = {
                "input_ids": batch_input_ids,
                "attention_mask": batch_attention_mask,
                "position_ids": position_ids,
                "past_key_values": batch_past_key_values,
                "use_cache": True,
            }

            print("\nModel input shapes:")
            _log_model_inputs(model_inputs)

            autocast_ctx = torch.autocast(device_type="cuda") if device.type == "cuda" else nullcontext()
            try:
                print("\nStarting model forward pass...")
                with autocast_ctx:
                    print("Running model forward pass...")
                    outputs = model(**model_inputs, return_dict=True)
            except Exception as e:
                # Diagnostic boundary: dump the inputs, then re-raise unchanged.
                print("\nERROR DURING GENERATION:")
                print(f"Exception type: {type(e)}")
                print(f"Exception message: {str(e)}")
                print("Exception args:", e.args)
                print("\nModel input shapes at time of error:")
                _log_model_inputs(model_inputs)
                raise

            print("\nModel outputs received:")
            print(f"Logits shape: {outputs.logits.shape}")
            print(f"Past key values: {len(outputs.past_key_values)} layers")
            print(f"First layer past_key_values shapes - Key: {outputs.past_key_values[0][0].shape}, Value: {outputs.past_key_values[0][1].shape}")

            print("\nComputing next tokens...")
            next_token_logits = outputs.logits[:, -1, :]
            print(f"Next token logits shape: {next_token_logits.shape}")

            next_tokens = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
            print(f"Next tokens shape: {next_tokens.shape}")
            print(f"Next tokens values: {next_tokens.tolist()}")

            print("\nCreating result object...")
            result = types.SimpleNamespace(
                logits=outputs.logits,
                past_key_values=outputs.past_key_values,
                generated_tokens=next_tokens,
            )
            print("Result object created successfully")

            results.append((sequences, result))
            print(f"\nAdded results for sequence group with seq_len = {seq_len}")
            print(f"Number of results so far: {len(results)}")

        print("\nCompleted all sequence groups")
        print(f"Total number of results: {len(results)}")
        print("="*80)

    return results


def _get_eos_token_id(processor):
    """Resolve the EOS token id used to stop generation.

    Uses a module-level ``get_eos_token_id(processor)`` helper if one exists
    (it is not defined in this file; another notebook cell may provide it).
    Otherwise falls back to the tokenizer's ``eos_token_id``, which may be
    None, in which case sequences stop only at their length limit.
    """
    resolver = globals().get("get_eos_token_id")
    if callable(resolver):
        return resolver(processor)
    return processor.tokenizer.eos_token_id


def _slice_past_key_values(past_key_values, idx, num_layers):
    """Return the per-layer (key, value) cache for batch row *idx*, keeping a batch dim of 1."""
    return [
        (past_key_values[layer_idx][0][idx:idx + 1], past_key_values[layer_idx][1][idx:idx + 1])
        for layer_idx in range(num_layers)
    ]


def _initial_batch_key(seq):
    """Group key: the non-batch shape of every tensor input, so a group concatenates with no padding."""
    return tuple(
        (name, tuple(value.shape[1:]))
        for name, value in sorted(seq["inputs"].items())
        if isinstance(value, torch.Tensor)
    )


def _append_next_token(seq, next_token, processor, eos_token_id, device):
    """Log and append a (1, 1) *next_token* to *seq*, grow its mask, and mark it finished on EOS/length."""
    tokenizer = processor.tokenizer
    current_text = tokenizer.decode(seq["generated_ids"].squeeze(), skip_special_tokens=True)
    next_token_text = tokenizer.decode(next_token.squeeze())
    token_id = next_token.item()
    print(f"\nOriginal prompt: {seq['original_prompt']}")
    print(f"Previously generated: {current_text}")
    print(f"New token: '{next_token_text}' (id: {token_id})")

    seq["generated_ids"] = torch.cat([seq["generated_ids"], next_token], dim=1)
    mask = seq["inputs"]["attention_mask"]
    seq["inputs"]["attention_mask"] = torch.cat(
        [mask, torch.ones((1, 1), dtype=mask.dtype, device=device)], dim=1
    )

    if token_id == eos_token_id or seq["generated_ids"].shape[1] >= seq["max_length"]:
        seq["finished"] = True
        print(f"Sequence {seq['id']} finished")


def process_sequences(model, processor, pending_threads, device, max_new_tokens=200, max_batch_size=4):
    """Greedily generate up to *max_new_tokens* tokens per thread, up to *max_batch_size* at a time.

    Threads are popped from the front of *pending_threads*, which consumes the
    caller's list.

    New sequences get a full forward pass over their prompt. Only prompts
    whose tensor inputs have identical shapes are batched together, so no
    right-padding is introduced. With padding, the first token of a shorter
    prompt would be read from a pad position and the cache would hold pad
    entries. Sequences that already have a cache take one step per loop
    iteration via ``process_subsequent_sequences``.

    Returns:
        The finished sequence dicts, in completion order. Their "inputs"
        tensors and "past_key_values" are cleared; "generated_ids" (prompt +
        response, shape (1, n)) and "prompt_length" are kept for decoding.
    """
    active_sequences = []
    finished_sequences = []
    sequence_id = 0
    num_layers = model.config.num_hidden_layers
    eos_token_id = _get_eos_token_id(processor)
    use_autocast = device.type == "cuda"

    with torch.no_grad():

        print("\nStarting sequence processing...")
        print(f"Number of layers: {num_layers}")
        print(f"EOS token ID: {eos_token_id}")
        print(f"Initial memory usage: {get_memory_usage()}")

        while pending_threads or active_sequences:
            # Top the active set up with new threads.
            while len(active_sequences) < max_batch_size and pending_threads:
                thread = pending_threads.pop(0)
                inputs = process_single_thread(thread, processor, device)
                prompt_length = inputs["input_ids"].size(-1)

                sequence = {
                    "id": sequence_id,
                    "inputs": inputs,
                    "generated_ids": inputs["input_ids"].clone(),  # (1, prompt_length)
                    "past_key_values": None,
                    "finished": False,
                    "max_length": prompt_length + max_new_tokens,
                    "prompt_length": prompt_length,
                    "original_prompt": thread["text"],  # kept for debug logging
                }
                active_sequences.append(sequence)
                sequence_id += 1
                print(f"\nInitialized sequence {sequence['id']}:")
                print(f"Input shape: {inputs['input_ids'].shape}")
                print(f"Generated IDs shape: {sequence['generated_ids'].shape}")
                print(f"Memory after sequence initialization: {get_memory_usage()}")

            if not active_sequences:
                break

            # Split before processing so a freshly initialised sequence is not
            # also stepped as a subsequent sequence in the same iteration.
            initial_sequences = [seq for seq in active_sequences if seq["past_key_values"] is None]
            subsequent_sequences = [seq for seq in active_sequences if seq["past_key_values"] is not None]

            if initial_sequences:
                print("\n" + "="*50)
                print("Processing initial sequences")
                print("="*50)

                groups = defaultdict(list)
                for seq in initial_sequences:
                    groups[_initial_batch_key(seq)].append(seq)

                for group in groups.values():
                    input_length = group[0]["inputs"]["input_ids"].size(-1)
                    print(f"Max input length: {input_length}")

                    with batch_memory_manager():
                        # Every prompt in the group already has input_length tokens,
                        # so prepare_batch_inputs adds no padding.
                        batch_inputs = prepare_batch_inputs(group, input_length, processor, device)
                        print("\nPrepared batch inputs:")
                        for k, v in batch_inputs.items():
                            if isinstance(v, torch.Tensor):
                                print(f"{k} shape: {v.shape}")

                        with torch.autocast(device_type="cuda") if use_autocast else nullcontext():
                            outputs = model(**batch_inputs)

                        for idx, seq in enumerate(group):
                            seq["past_key_values"] = _slice_past_key_values(outputs.past_key_values, idx, num_layers)
                            next_token = torch.argmax(outputs.logits[idx, -1, :], dim=-1).view(1, 1)
                            _append_next_token(seq, next_token, processor, eos_token_id, device)
                            print(f"Memory after processing sequence {seq['id']}: {get_memory_usage()}")

            if subsequent_sequences:
                print("\n" + "="*50)
                print("Processing subsequent sequences")
                print("="*50)

                with batch_memory_manager():
                    batch_results = process_subsequent_sequences(
                        subsequent_sequences, model, processor, num_layers, device
                    )

                    for group, result in batch_results:
                        for idx, seq in enumerate(group):
                            print(f"\nSequence {seq['id']} Generation State:")
                            print(f"Original prompt: {seq['original_prompt']}")
                            current_text = processor.tokenizer.decode(seq['generated_ids'].squeeze(), skip_special_tokens=True)
                            print(f"Currently generated: {current_text}")

                            seq["past_key_values"] = _slice_past_key_values(result.past_key_values, idx, num_layers)

                            next_token = result.generated_tokens[idx].reshape(1, 1)
                            print(f"Next token shape: {next_token.shape}")
                            _append_next_token(seq, next_token, processor, eos_token_id, device)
                            print(f"Updated generated_ids shape: {seq['generated_ids'].shape}")
                            print(f"Memory after processing sequence {seq['id']}: {get_memory_usage()}")

            # Drop tensors held by finished sequences; generated_ids is kept for decoding.
            newly_finished = [seq for seq in active_sequences if seq["finished"]]
            for seq in newly_finished:
                seq_inputs = seq["inputs"]
                for key in list(seq_inputs):
                    if isinstance(seq_inputs[key], torch.Tensor):
                        seq_inputs[key] = None
                seq["past_key_values"] = None

            active_sequences = [seq for seq in active_sequences if not seq["finished"]]
            finished_sequences.extend(newly_finished)

            print(f"\nActive sequences remaining: {len(active_sequences)}")
            print(f"Finished sequences: {len(finished_sequences)}")
            # Collect first so freed tensors' blocks can then leave the CUDA cache.
            gc.collect()
            torch.cuda.empty_cache()
            print(f"Memory after sequence cleanup: {get_memory_usage()}")

        print("\nSequence processing completed")
        print(f"Final memory usage: {get_memory_usage()}")
    return finished_sequences

def _right_pad(tensor, length, value):
    """Right-pad a (batch, seq) tensor along dim 1 to *length* with *value*; longer tensors pass through."""
    missing = length - tensor.size(1)
    if missing <= 0:
        return tensor
    padding = torch.full((tensor.size(0), missing), value, dtype=tensor.dtype, device=tensor.device)
    return torch.cat([tensor, padding], dim=1)


def prepare_batch_inputs(sequences, max_length, processor, device):
    """Concatenate per-sequence inputs along dim 0 into one model batch.

    "input_ids" are right-padded to *max_length* with the tokenizer's pad id,
    falling back to 0 when the tokenizer has none. "attention_mask" is
    right-padded with 0. Other tensors are concatenated as-is, so they must
    already agree in every non-batch dim. Non-tensor values are taken from
    the first sequence. ``use_cache=True`` is always added.

    Args:
        sequences: non-empty list of dicts, each with an "inputs" mapping.
        max_length: target length for "input_ids" / "attention_mask".
        processor: supplies ``tokenizer.pad_token_id``.
        device: kept for interface compatibility; padding is created on each
            source tensor's own device so concatenation never mixes devices.

    Raises:
        ValueError: if *sequences* is empty.
    """
    if not sequences:
        raise ValueError("prepare_batch_inputs() needs at least one sequence")

    batch_inputs = {}
    for key, first_value in sequences[0]["inputs"].items():
        if not isinstance(first_value, torch.Tensor):
            batch_inputs[key] = first_value
            continue

        parts = [seq["inputs"][key] for seq in sequences]
        if key == "attention_mask":
            parts = [_right_pad(t, max_length, 0) for t in parts]
        elif key == "input_ids":
            pad_token_id = processor.tokenizer.pad_token_id
            if pad_token_id is None:
                # Padded positions are masked out, so any valid id works; 0 always exists.
                pad_token_id = 0
            parts = [_right_pad(t, max_length, pad_token_id) for t in parts]
        batch_inputs[key] = torch.cat(parts, dim=0)

    batch_inputs["use_cache"] = True
    return batch_inputs

def decode_sequences(finished_sequences, processor):
    """Decode each finished sequence into its full text and its response-only text.

    Args:
        finished_sequences: sequence dicts with "id", "generated_ids"
            (expected shape (1, n): prompt followed by generated tokens) and
            "prompt_length".
        processor: object whose ``tokenizer.decode`` turns ids into text.

    Returns:
        One dict per sequence with "sequence_id", "full_text",
        "response_text", "total_tokens" and "response_tokens". Nothing is
        printed; ``main`` does the printing.
    """
    decode = processor.tokenizer.decode
    results = []
    for seq in finished_sequences:
        # reshape(-1) rather than squeeze(): squeezing a (1, 1) tensor yields a
        # 0-d tensor that can be neither sliced nor passed to len().
        generated_ids = seq["generated_ids"].reshape(-1).cpu()
        response_ids = generated_ids[seq["prompt_length"]:]

        results.append({
            "sequence_id": seq["id"],
            "full_text": decode(generated_ids, skip_special_tokens=True),
            "response_text": decode(response_ids, skip_special_tokens=True),
            "total_tokens": len(generated_ids),
            "response_tokens": len(response_ids),
        })

    return results

def main():
    """Load Molmo-7B-D, run batched greedy generation on two example image threads, and print the results."""
    model_name = "allenai/Molmo-7B-D-0924"
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    # device_map="auto" already places the weights (splitting or offloading
    # them if they do not fit). Calling .to(device) on an accelerate-dispatched
    # model raises when modules are offloaded to CPU/disk, so it is omitted.
    # NOTE(review): inputs are built on the module-level `device`; this
    # assumes that device hosts the model's first layers (true on one GPU).
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map="auto",
    )

    processor = AutoProcessor.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map="auto",
    )

    # Print model configuration and debug processor output (output only;
    # the returned tensors are not needed).
    print_model_config(model)
    debug_processor_output(processor)
    print("\nDebug output received. Proceeding with main script...\n")

    pending_threads = [
        {
            "image_url": "https://picsum.photos/id/237/536/354",
            "text": "Describe this image."
        },
        {
            "image_url": "https://picsum.photos/id/238/536/354",
            "text": "What do you see in this picture?"
        },
    ]

    finished_sequences = process_sequences(
        model=model,
        processor=processor,
        pending_threads=pending_threads,
        device=device,
        max_new_tokens=200,
        max_batch_size=4
    )

    results = decode_sequences(finished_sequences, processor)

    print("\nGeneration Results:")
    print("=" * 50)
    for result in results:
        print(f"\nSequence {result['sequence_id']}:")
        print(f"Full text:\n{result['full_text']}\n")
        print(f"Response:\n{result['response_text']}\n")
        print(f"Total tokens: {result['total_tokens']}")
        print(f"Response tokens: {result['response_tokens']}")
        print("-" * 50)

    if device.type == "cuda":
        torch.cuda.empty_cache()


if __name__ == "__main__":
    main()