In [8]:
from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor
import torch


In [None]:
processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
model = LlavaNextForConditionalGeneration.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# Save the model and processor
processor.save_pretrained("./llava_model/processor")
model.save_pretrained("./llava_model/model")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [9]:
from PIL import Image
import requests
import torch

# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load model and processor once (not in every function call)
processor = LlavaNextProcessor.from_pretrained("./llava_model/processor")
model = LlavaNextForConditionalGeneration.from_pretrained(
    "./llava_model/model",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    low_cpu_mem_usage=True,
    device_map="auto" if torch.cuda.is_available() else None,
)

# Move to GPU if available
if torch.cuda.is_available():
    model = model.to(device)

print(f"Model loaded on: {next(model.parameters()).device}")
print(f"Model dtype: {next(model.parameters()).dtype}")

Using device: cpu


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model loaded on: cpu
Model dtype: torch.float32


In [10]:
from functools import lru_cache
import time

# Cache for loaded images to avoid reloading
@lru_cache(maxsize=100)
def load_image_cached(image_path):
    """Cache loaded images to avoid repeated loading"""
    if image_path.startswith("http"):
        image = Image.open(requests.get(image_path, stream=True).raw).convert("RGB")
    else:
        image = Image.open(image_path).convert("RGB")
    return image

def chat_with_image_optimized(image_paths, question, max_new_tokens=256, temperature=0.1):
    """Optimized version with GPU support, caching, and better performance"""
    start_time = time.time()
    
    # Load images with caching
    images = []
    for image_path in image_paths:
        image = load_image_cached(image_path)
        images.append(image)
    
    print(f"Image loading time: {time.time() - start_time:.2f}s")
    
    # Create conversation with multiple images
    content = [{"type": "text", "text": question}]
    
    # Add each image to the content
    for _ in images:
        content.append({"type": "image"})
    
    conversation = [
        {
            "role": "user",
            "content": content,
        }
    ]

    # Preprocess inputs
    preprocessing_start = time.time()
    prompt = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
    )
    
    inputs = processor(
        images=images,
        text=prompt,
        return_tensors="pt",
        padding=True,  # Add padding for batch processing
    )
    
    # Move inputs to the same device as model
    inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
    
    print(f"Preprocessing time: {time.time() - preprocessing_start:.2f}s")

    # Generate response with optimizations
    generation_start = time.time()
    with torch.no_grad():
        # Use torch.inference_mode for better performance
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                pad_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,  # Enable KV cache for faster generation
                num_beams=1,     # Disable beam search for speed
            )
    
    print(f"Generation time: {time.time() - generation_start:.2f}s")
    
    # Decode response
    decoding_start = time.time()
    response = processor.decode(outputs[0], skip_special_tokens=True)
    print(f"Decoding time: {time.time() - decoding_start:.2f}s")
    
    print(f"Total time: {time.time() - start_time:.2f}s")
    
    return response.split("### Response:")[-1].strip() if "### Response:" in response else response

def batch_chat_with_images(image_batches, questions, max_new_tokens=256):
    """Process multiple image-question pairs in batches for better efficiency"""
    results = []
    
    for i, (image_paths, question) in enumerate(zip(image_batches, questions)):
        print(f"\nProcessing batch {i+1}/{len(image_batches)}")
        result = chat_with_image_optimized(image_paths, question, max_new_tokens)
        results.append(result)
        
        # Optional: Clear GPU cache periodically
        if torch.cuda.is_available() and (i + 1) % 5 == 0:
            torch.cuda.empty_cache()
    
    return results

# Original function for compatibility
def chat_with_image(image_paths, question):
    """Original function - now calls optimized version"""
    return chat_with_image_optimized(image_paths, question)

In [11]:
# Test the optimized version
result = chat_with_image_optimized(
    image_paths=["friend.jpeg", "spaghetti.jpeg"], 
    question="What is in the image?",
    max_new_tokens=256,
    temperature=0.1
)
print("Response:", result)

# Example of batch processing for multiple queries
# batch_results = batch_chat_with_images(
#     image_batches=[
#         ["friend.jpeg", "spaghetti.jpeg"],
#         ["friend.jpeg"],
#         ["spaghetti.jpeg"]
#     ],
#     questions=[
#         "What is in the image?",
#         "Describe the person in detail",
#         "What type of food is this?"
#     ]
# )

Image loading time: 0.06s
Preprocessing time: 0.10s
Generation time: 2624.40s
Decoding time: 0.03s
Total time: 2624.60s
Response: [INST]  
 
What is in the image? [/INST] The image shows a group of five people standing outdoors, likely in a park or a similar natural setting. They are posing for the photo, with some of them making hand gestures. In the foreground, there is a plate of spaghetti with a sprinkle of cheese on top, suggesting that they might be enjoying a meal together. The setting appears to be casual and relaxed, with the group dressed in casual attire. 
