In [None]:
!pip install transformers accelerate bitsandbytes

In [None]:
import json
import re

import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


# ========================
# Load Qwen Model
# ========================
# Hugging Face Hub identifier of the checkpoint used for meta-captioning.
model_name = "Qwen/Qwen2.5-7B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Passing `load_in_8bit=True` straight to `from_pretrained` is deprecated in
# transformers; the supported way to request bitsandbytes 8-bit weights is a
# `BitsAndBytesConfig` handed over through `quantization_config`.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)

# Device string that tokenized prompts are moved to before generation.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(" Qwen loaded successfully in 8-bit mode")


In [None]:
# ========================
# Input / output locations
# ========================
# Swap both paths for the matching files when running on SDXL or SD2 captions.
input_json = "/content/drive/MyDrive/Flux-Dev_mscoco_dense_caps.json"
output_csv = "/content/drive/MyDrive/FluxDev_QWEN_META_CAPS.csv"

# ========================
# Read the dense captions
# ========================
# The run loop further down iterates this object with `.items()` and calls
# `.values()` on each entry, i.e. it is used as a mapping of mappings.
with open(input_json, mode="r", encoding="utf-8") as json_file:
    data = json.loads(json_file.read())

print(f" Loaded {len(data)} images from {input_json}")


# ========================
# Meta-caption function
# ========================
def _unique_captions(dense_captions):
    """Return stripped, non-empty string captions in first-seen order, without duplicates."""
    seen = set()
    unique = []
    for cap in dense_captions or []:
        if not isinstance(cap, str):
            # Null / non-text entries cannot be stripped or stitched; skip them
            # rather than crash the whole run on one malformed record.
            continue
        cap = cap.strip()
        if cap and cap not in seen:
            seen.add(cap)
            unique.append(cap)
    return unique


def _extract_caption(response):
    """Pull the caption text out of a raw model reply, tolerating missing tags."""
    match = re.search(r"<caption>(.*?)</caption>", response, re.DOTALL)
    if match:
        return match.group(1).strip()

    # No closed tag pair. If the reply was cut off before `</caption>`, keep
    # what follows the opening tag so the literal tag does not leak into the output.
    _, opening_tag, remainder = response.partition("<caption>")
    if opening_tag:
        return remainder.strip()

    # The model ignored the tag instruction entirely: fall back to the raw reply.
    return response.strip()


def get_meta_caption(dense_captions, meta_model, meta_tokenizer, device="cuda"):
    """
    Combine multiple dense captions into one unified meta-caption.

    Args:
        dense_captions: Iterable of sub-region captions. Blank, duplicate
            (after stripping) and non-string entries are ignored.
        meta_model: Object exposing ``generate(...)``; called with the
            tokenized prompt, ``max_new_tokens=400`` and ``do_sample=False``.
        meta_tokenizer: Tokenizer exposing ``apply_chat_template``,
            ``__call__`` and ``batch_decode``.
        device: Device the tokenized prompt is moved to before generation.

    Returns:
        The text found between ``<caption>`` and ``</caption>`` in the model's
        reply, stripped. If the closing tag is missing, the text after the
        opening tag; if there are no tags at all, the whole stripped reply.
        Returns ``""`` without calling the model when no usable caption is given.
    """
    # Bail out before building any prompt or touching the model.
    filtered_caps = _unique_captions(dense_captions)
    if not filtered_caps:
        return ""

    # Prompt: fixed instructions followed by the numbered sub-captions.
    prompt = """I am providing you with captions for sub-regions of an image.
You need to stitch all the captions into one unified caption for the entire image.
Do not add new information. Keep it short but detailed.
Ignore mentions of backgrounds. Do not hallucinate details.
Return the final caption wrapped inside <caption></caption> tags.
"""
    prompt += "".join(
        f"{number}. {cap}\n" for number, cap in enumerate(filtered_caps, start=1)
    )

    messages = [
        {"role": "system", "content": "You are a meta image captioning model. You combine multiple sub-captions into one coherent grounded caption."},
        {"role": "user", "content": prompt}
    ]

    # Format for Qwen chat
    text = meta_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = meta_tokenizer([text], return_tensors="pt").to(device)

    # Generate. Unpacking the tokenizer output forwards the attention mask
    # together with the input ids instead of dropping it.
    generated_ids = meta_model.generate(
        **model_inputs,
        max_new_tokens=400,
        do_sample=False,
    )

    # `generate` returns prompt + completion; keep only the newly produced tokens.
    completions = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs["input_ids"], generated_ids)
    ]
    response = meta_tokenizer.batch_decode(completions, skip_special_tokens=True)[0]

    return _extract_caption(response)

# ========================
# Generate one meta-caption per image
# ========================
# Each entry of `data` maps an image name to a mapping whose values are the
# dense captions for that image; one output record is collected per image.
rows = []
progress = tqdm(data.items(), desc="Processing Images")
for image_name, segments in progress:
    dense_captions = [*segments.values()]
    meta_caption = get_meta_caption(dense_captions, model, tokenizer, device)
    record = {"image_name": image_name, "Meta_caption_Qwen": meta_caption}
    rows.append(record)

# ========================
# Write results to CSV
# ========================
# The frame is built once from the collected records, then written without
# the pandas index column.
df = pd.DataFrame(rows)
df.to_csv(output_csv, encoding="utf-8", index=False)

print(f"  Meta-captions saved to {output_csv}")
