In [None]:
!pip install transformers accelerate bitsandbytes pandas pillow


In [None]:
import os, torch, pandas as pd
from PIL import Image, UnidentifiedImageError
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, BitsAndBytesConfig

#  Paths
#  image_folder: directory scanned for .jpg/.jpeg/.png files to caption.
#  output_csv:   file that receives one (image_id, caption) row per image.
image_folder = "/content/drive/MyDrive/Fine-Grained-Hallucination-main/sd_2_outputs"
output_csv = "/content/drive/MyDrive/llava_mistral_captions.csv"

#  Device setup: prefer the GPU whenever torch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"

#  4-bit quantization config: NF4 weights with double quantization,
#  float16 used as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
)

#  Load processor + model
model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
processor = LlavaNextProcessor.from_pretrained(model_id)
#  The model class is imported from transformers itself, so no repository-provided
#  code has to run; trust_remote_code is therefore deliberately not enabled.
#  .eval() switches the model to inference mode (no dropout etc.).
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=bnb_config,
).eval()

#  Caption one image
def caption_image(img: Image.Image, max_new_tokens: int = 40) -> str:
    """Generate a short, single-line caption for one image.

    Builds a one-turn chat prompt (instruction text followed by the image),
    runs generation on the module-level ``model`` and returns only the text
    produced by the model.

    Args:
        img: PIL image to describe.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated caption with surrounding whitespace removed.
    """
    instruction = (
        "Generate a single line caption describing what's in the image using fewer words. "
        "Keep it short and simple, and generate only the caption."
    )
    conv = [{"role": "user", "content": [{"type": "text", "text": instruction}, {"type": "image"}]}]
    prompt = processor.apply_chat_template(conv, add_generation_prompt=True)
    inputs = processor(images=img, text=prompt, return_tensors="pt").to(device)
    output = model.generate(**inputs, max_new_tokens=max_new_tokens)

    #  generate() returns the prompt tokens followed by the newly generated ones.
    #  Dropping the prompt by length keeps the result independent of whether the
    #  "[/INST]" marker survives decoding with skip_special_tokens=True.
    prompt_len = inputs["input_ids"].shape[1]
    decoded = processor.decode(output[0][prompt_len:], skip_special_tokens=True)

    #  Guard: if the marker still shows up in the text, keep only what follows it.
    return decoded.split("[/INST]")[-1].strip()

#  Caption loop
#  Every image in image_folder yields exactly one row; failures are recorded as
#  an "Error: ..." caption instead of aborting the whole run.
captions = []
image_files = sorted(fn for fn in os.listdir(image_folder) if fn.lower().endswith((".jpg", ".jpeg", ".png")))
total = len(image_files)

for i, fn in enumerate(image_files, start=1):
    img_path = os.path.join(image_folder, fn)
    try:
        #  The context manager closes the underlying file as soon as convert()
        #  has produced an independent in-memory copy of the pixels.
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")
        caption = caption_image(image)
        print(f"[{i}/{total}]  {fn} → {caption}")
    except UnidentifiedImageError:
        #  Pillow could not recognise the file as an image.
        caption = "Error: Corrupted image"
        print(f"[{i}/{total}]  {fn} → Invalid image")
    except Exception as e:
        #  Best effort: keep going and store the failure reason in the row.
        caption = f"Error: {e}"
        print(f"[{i}/{total}]  {fn} → {caption}")

    captions.append({"image_id": fn, "caption": caption})
    torch.cuda.empty_cache()  #  Prevent memory buildup

#  Save to CSV
#  Explicit columns keep the header present even when no image was found, so
#  the file can always be read back with pd.read_csv.
pd.DataFrame(captions, columns=["image_id", "caption"]).to_csv(output_csv, index=False)
print(f"\n✅ Captions saved to: {output_csv}")


In [None]:
import pandas as pd

# Load the CSV
df = pd.read_csv("/content/drive/MyDrive/llava_mistral_captions.csv")

# Sort rows by the first run of digits in the filename.
# The key is kept as float so that filenames without any digits (NaN) do not
# raise on conversion; those rows are placed at the end. The stable sort keeps
# rows with equal keys in their original order.
sort_key = df["image_id"].str.extract(r"(\d+)", expand=False).astype(float)
df = (
    df.assign(sort_key=sort_key)
    .sort_values("sort_key", kind="stable", na_position="last")
    .drop(columns="sort_key")
)

# Save the sorted CSV
df.to_csv("/content/drive/MyDrive/llava_mistral_captions_sorted.csv", index=False)
