# In [None]:

'''
Run the following experiment:

1. Define a prompt and use Lumina 2.0 to generate two images from the prompt
2.
    (a) Use Qwen 2.5 VL 7B to describe each of the two images in detail.
    (b) Use Qwen 2.5 VL 7B to generate a single description that is correct for both images. Ask it to include as much information as possible in the description.
    (c) For each generated image, ask Qwen to describe the contents/activities in that image that are not present in the other image.
3. Use the original prompt as the reference and the outputs of 2a, 2b, and 2c (i: unique to image 1, ii: unique to image 2) as the candidates. Compute the BERTScore of each (reference, candidate) pair and display the precision, recall, and F1 score for all cases in a nicely formatted table.

'''

from io import BytesIO

import matplotlib.pyplot as plt
import pandas as pd
import requests
import torch
from bert_score import score as bert_score
from diffusers import DiffusionPipeline
from PIL import Image
# NOTE: transformers does not export a ``BertScore`` class (importing it raised
# ImportError); BERTScore is computed with the ``bert_score`` package above.
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import LlavaForConditionalGeneration, LlavaProcessor

# Pick the compute device once for the whole script: the CUDA GPU when one is
# visible to torch, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

def generate_images_with_lumina(
    prompt,
    num_images=2,
    model_id="Alpha-VLLM/Lumina-Image-2.0",
    num_inference_steps=30,
    seed=0,
):
    """Generate ``num_images`` images for ``prompt`` with Lumina-Image 2.0.

    Every image is also written to ``generated_image_<k>.png`` (k is 1-based)
    in the current working directory.

    Args:
        prompt: Text prompt passed to the diffusion pipeline.
        num_images: How many images to sample.
        model_id: Hugging Face repo of the Lumina 2.0 diffusers pipeline.
        num_inference_steps: Denoising steps per image.
        seed: Base seed. Image ``i`` (0-based) is sampled with ``seed + i`` so
            a run is reproducible while the images still differ from each
            other. Pass ``None`` for unseeded sampling.

    Returns:
        List of generated PIL images, in generation order.
    """
    print(f"Generating {num_images} images from prompt: {prompt}")

    # Half precision is poorly supported on CPU, so only use a reduced dtype on
    # GPU; bfloat16 is what the diffusers Lumina 2.0 example uses.
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
    pipe = pipe.to(device)

    images = []
    for i in range(num_images):
        # A CPU generator gives the same noise regardless of the target device.
        generator = (
            None
            if seed is None
            else torch.Generator(device="cpu").manual_seed(seed + i)
        )
        image = pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            generator=generator,
        ).images[0]
        images.append(image)
        # Save the generated image
        image.save(f"generated_image_{i+1}.png")
        print(f"Image {i+1} generated and saved as generated_image_{i+1}.png")

    return images

def load_qwen_vl(model_id="Qwen/Qwen2.5-VL-7B-Instruct"):
    """Load the Qwen2.5-VL 7B model and its matching processor.

    Args:
        model_id: Hugging Face repo of a Qwen2.5-VL checkpoint.

    Returns:
        ``(model, processor)``, with the model moved to the module-level
        ``device``.
    """
    # Qwen2.5-VL has its own architecture; the Llava classes cannot load its
    # weights, so use the dedicated model class and the auto-resolved processor.
    from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

    # Reduced precision only on GPU; half-precision kernels are poorly
    # supported on CPU.
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    processor = AutoProcessor.from_pretrained(model_id)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=dtype,
    ).to(device)

    return model, processor

def get_image_description(model, processor, image, prompt, max_new_tokens=300):
    """Ask the vision-language model a question about a single image.

    The prompt is wrapped in the processor's chat template so that the
    model-specific image placeholder tokens are inserted next to the text.
    Decoding is greedy (``do_sample=False``).

    Args:
        model: Vision-language model exposing ``generate``, ``device`` and
            ``dtype`` (e.g. the one returned by ``load_qwen_vl``).
        processor: The model's processor; must provide
            ``apply_chat_template`` and ``batch_decode``.
        image: PIL image to condition on.
        prompt: Instruction or question for the model.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        Only the model's answer (the echoed prompt is removed), stripped of
        surrounding whitespace.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    # Passing the raw prompt would leave no image placeholder for the
    # processor to expand; the chat template inserts it.
    chat_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Move to the model's device; floating tensors (pixel values) are cast to
    # the model's dtype, integer token ids keep theirs.
    inputs = processor(text=[chat_text], images=[image], return_tensors="pt").to(
        model.device, model.dtype
    )

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    # generate() returns prompt + completion. Slice off the prompt tokens
    # rather than string-splitting on a role marker whose exact spelling
    # ("assistant:", "Assistant:", "assistant\n") depends on the template.
    prompt_length = inputs["input_ids"].shape[1]
    answer_ids = output_ids[:, prompt_length:]
    return processor.batch_decode(answer_ids, skip_special_tokens=True)[0].strip()

def detailed_image_description(model, processor, image):
    """Return the VLM's free-form, detailed description of a single image.

    A thin wrapper that forwards a fixed instruction to
    ``get_image_description``.
    """
    return get_image_description(
        model,
        processor,
        image,
        prompt="Please describe this image in detail.",
    )

def _side_by_side(left, right, gap=16):
    """Return a new RGB canvas with ``left`` pasted at x=0 and ``right`` after it.

    A ``gap``-pixel white strip separates the two images so the boundary
    between them is visible. The canvas is as tall as the taller image; any
    area not covered by an image stays white.
    """
    canvas = Image.new(
        "RGB",
        (left.width + gap + right.width, max(left.height, right.height)),
        "white",
    )
    canvas.paste(left, (0, 0))
    canvas.paste(right, (left.width + gap, 0))
    return canvas


def common_description(model, processor, image1, image2):
    """Ask the VLM for one description that is correct for both images.

    The images are shown as a single side-by-side composite (``image1`` on
    the left, ``image2`` on the right). The prompt states that layout
    explicitly, since the model only receives one picture.

    Returns:
        The model's description text.
    """
    combined_image = _side_by_side(image1, image2)
    prompt = (
        "Look at these two images. The first image is on the left, the second "
        "image is on the right. Generate a detailed description that would be "
        "correct for both images. Include as much information as possible that "
        "applies to both images."
    )
    return get_image_description(model, processor, combined_image, prompt)


def unique_elements(model, processor, image1, image2):
    """Describe contents/activities in ``image1`` that are absent from ``image2``.

    Uses the same side-by-side composite as ``common_description``: ``image1``
    on the left, ``image2`` on the right. Swap the arguments to get the
    opposite direction.

    Returns:
        The model's description text.
    """
    combined_image = _side_by_side(image1, image2)
    prompt = (
        "Look at these two images. The first image is on the left, the second "
        "image is on the right. Describe the contents or activities in the first "
        "image (left) that are not present in the second image (right)."
    )
    return get_image_description(model, processor, combined_image, prompt)

def calculate_bert_scores(reference, candidates, lang="en", rescale_with_baseline=True):
    """Score every candidate text against one reference with BERTScore.

    Args:
        reference: The reference text (in this experiment, the generation
            prompt).
        candidates: Mapping of label -> candidate text. Row order follows the
            mapping's iteration order.
        lang: Language code forwarded to ``bert_score.score`` (selects its
            default model).
        rescale_with_baseline: Forwarded to ``bert_score.score``. Rescaled
            scores are not confined to [0, 1].

    Returns:
        DataFrame with columns ``Description``, ``Precision``, ``Recall``,
        ``F1`` and one row per candidate. Empty (with those columns) when
        ``candidates`` is empty.
    """
    columns = ["Description", "Precision", "Recall", "F1"]
    if not candidates:
        # Nothing to score: don't load the BERTScore model for an empty request.
        return pd.DataFrame(columns=columns)

    labels = list(candidates)
    candidate_texts = [candidates[label] for label in labels]
    references = [reference] * len(candidate_texts)

    P, R, F1 = bert_score(
        candidate_texts,
        references,
        lang=lang,
        rescale_with_baseline=rescale_with_baseline,
    )

    return pd.DataFrame(
        {
            "Description": labels,
            "Precision": P.tolist(),
            "Recall": R.tolist(),
            "F1": F1.tolist(),
        },
        columns=columns,
    )

def display_images(images, size_per_image=10):
    """Show images side by side in a single matplotlib row.

    Args:
        images: List of images accepted by ``Axes.imshow`` (e.g. PIL images).
            An empty list draws nothing.
        size_per_image: Width and height, in inches, given to each panel.
    """
    if not images:
        # plt.subplots(1, 0) raises, and there is nothing to draw anyway.
        return

    count = len(images)
    # squeeze=False always returns a 2-D array of Axes, so a single image needs
    # no special case.
    _fig, axes = plt.subplots(
        1, count, figsize=(size_per_image * count, size_per_image), squeeze=False
    )
    for number, (ax, image) in enumerate(zip(axes[0], images), start=1):
        ax.imshow(image)
        ax.set_title(f"Image {number}")
        ax.axis("off")

    plt.tight_layout()
    plt.show()

def main(prompt="A serene mountain lake at sunset with a small boat"):
    """Run the full experiment described at the top of this file.

    1. Generate two images from ``prompt`` with Lumina 2.0 and display them.
    2a. Describe each image in detail with Qwen2.5-VL.
    2b. Ask for a single description that is correct for both images.
    2c. For each image, ask what it contains that the other one does not.
    3. Score every description against ``prompt`` with BERTScore, print a
       formatted table, and save it to ``bert_scores_results.csv``.

    Args:
        prompt: Text used both to generate the images and as the BERTScore
            reference.

    Returns:
        The BERTScore results DataFrame.
    """
    # 1. Generate the images from the prompt and show them.
    images = generate_images_with_lumina(prompt)
    display_images(images)

    # 2. Load Qwen 2.5 VL once and reuse it for every query below.
    model, processor = load_qwen_vl()

    # 2a. Detailed description of each image.
    description1 = detailed_image_description(model, processor, images[0])
    description2 = detailed_image_description(model, processor, images[1])

    print("\nDetailed description of Image 1:")
    print(description1)
    print("\nDetailed description of Image 2:")
    print(description2)

    # 2b. One description that holds for both images.
    common_desc = common_description(model, processor, images[0], images[1])
    print("\nCommon description for both images:")
    print(common_desc)

    # 2c. What each image has that the other lacks (argument order picks the
    # image being described).
    unique_in_1 = unique_elements(model, processor, images[0], images[1])
    unique_in_2 = unique_elements(model, processor, images[1], images[0])

    print("\nUnique elements in Image 1:")
    print(unique_in_1)
    print("\nUnique elements in Image 2:")
    print(unique_in_2)

    # 3. BERTScore of every candidate against the original prompt.
    candidates = {
        "Image 1 Description": description1,
        "Image 2 Description": description2,
        "Common Description": common_desc,
        "Unique in Image 1": unique_in_1,
        "Unique in Image 2": unique_in_2,
    }
    bert_scores_df = calculate_bert_scores(prompt, candidates)

    print("\nBertScore Results:")
    # Fixed 4-decimal formatting without the index column keeps the table
    # readable; the CSV below still stores full precision.
    print(
        bert_scores_df.to_string(
            index=False, float_format=lambda value: f"{value:.4f}"
        )
    )

    # Save results to CSV
    bert_scores_df.to_csv("bert_scores_results.csv", index=False)
    print("Results saved to bert_scores_results.csv")

    return bert_scores_df

# Entry point: run the whole experiment when this file is executed directly.
if __name__ == "__main__":
    main()