# This code takes the image segments generated by Grounding DINO + SAM (stored in `Flux-Dev_images_segments`) and passes each one to InstructBLIP (a BLIP-2-based model) to generate dense captions.

In [None]:
# Install the libraries the following cells rely on; bitsandbytes is required
# for the 8-bit model load. NOTE(review): tqdm is not used in the cells shown.
!pip install transformers accelerate bitsandbytes torch pillow tqdm


In [None]:
# === Load InstructBLIP (BLIP-2-based, Vicuna-7B) in 8-bit ===
# This cell runs before the one that imports os/json/PIL, and no cell shown
# imports these transformers classes, so import them here to avoid NameError.
from transformers import (
    BitsAndBytesConfig,
    InstructBlipForConditionalGeneration,
    InstructBlipProcessor,
)

# Single source of truth so the processor and weights always come from the
# same checkpoint (tokenizer / image preprocessing must match the model).
BLIP_MODEL_ID = "Salesforce/instructblip-vicuna-7b"

# Load the weights in 8-bit via bitsandbytes to reduce GPU memory use.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

blip_processor = InstructBlipProcessor.from_pretrained(BLIP_MODEL_ID)
blip_model = InstructBlipForConditionalGeneration.from_pretrained(
    BLIP_MODEL_ID,
    device_map="auto",  # let accelerate place the layers on available device(s)
    quantization_config=bnb_config,
)

In [None]:
import os
import json
from PIL import Image

# === Paths ===
# base_path: directory whose sub-folders hold the segment PNGs to caption.
# output_json: nested {folder_name: {segment_file: caption}} results file.
base_path, output_json = (
    "/content/drive/MyDrive/Flux-Dev_images_segments",
    "/content/drive/MyDrive/BLIP2_FLUX-DEV_CAPS.json",
)

# === Prompt ===
# Instruction sent with every segment. The caption function also strips this
# exact text if the model echoes it back, so keep the two in sync. Editing it
# means resumed folders in output_json were captioned with the old wording.
prompt = "".join((
    "Describe the image with a focus on the intricate details of the object, ",
    "including their color, shape, and number. Include any physical aspects that ",
    "appear unusual or incorrect according to general knowledge.",
))

# === Function to caption one image (uses existing blip_model + blip_processor) ===
def get_blip2_caption(img_path, max_new_tokens=100):
    """Caption one image with InstructBLIP using the module-level ``prompt``.

    Args:
        img_path: Path to an image file PIL can open; it is converted to RGB.
        max_new_tokens: Maximum number of tokens generated for the caption.
            This budget excludes the prompt, so the long instruction prompt
            cannot eat into the caption length.

    Returns:
        The decoded caption (greedy decoding), with surrounding whitespace and
        any echoed copy of ``prompt`` removed.

    Raises:
        Any error from PIL, the processor or ``generate`` (e.g. an unreadable
        file) propagates to the caller.
    """
    # Close the underlying file promptly; convert() returns an independent copy,
    # so the image stays usable after the handle is closed.
    with Image.open(img_path) as raw:
        img = raw.convert("RGB")

    inputs = blip_processor(images=img, text=prompt, return_tensors="pt").to(blip_model.device)
    # max_length would count the prompt tokens too and truncate long captions;
    # max_new_tokens bounds only the newly generated text.
    out = blip_model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    caption = blip_processor.decode(out[0], skip_special_tokens=True).strip()

    # The decoded output may echo the prompt before the answer; drop it.
    # strip() above ensures a leading space from decoding can't defeat the check.
    if caption.startswith(prompt):
        caption = caption[len(prompt):].strip()

    return caption

# === Load existing captions if JSON already exists ===
# Resume support: a previously saved results file seeds all_captions, and the
# main loop skips folders that already appear in it.
if not os.path.exists(output_json):
    all_captions = {}
else:
    with open(output_json, "r") as saved:
        all_captions = json.load(saved)
    print(f" Resuming from {output_json}")

# === Main Loop ===
# Captions recorded for segments whose captioning raised start with this
# marker; such segments are retried on the next (resumed) run.
ERROR_PREFIX = "ERROR: "


def _save_captions(captions, path):
    """Write ``captions`` to ``path`` as indented JSON, replacing the file atomically.

    The data is dumped to a sibling ``.tmp`` file first and then moved over the
    target with ``os.replace``. A crash mid-write therefore leaves the previous
    JSON intact instead of a truncated file that would break resuming.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(captions, f, indent=2)
    os.replace(tmp_path, path)


folders = sorted(os.listdir(base_path))
for folder_name in folders:
    previous = all_captions.get(folder_name, {})
    # Segments that failed in an earlier run: the folder is not "done" until
    # they succeed. (A permanently broken file will be retried on every run.)
    failed = {
        seg for seg, cap in previous.items()
        if isinstance(cap, str) and cap.startswith(ERROR_PREFIX)
    }
    if folder_name in all_captions and not failed:  # skip already processed
        print(f"Skipping folder {folder_name} (already done)")
        continue

    folder_path = os.path.join(base_path, folder_name)
    if not os.path.isdir(folder_path):
        continue

    print(f"\n Processing folder {folder_name} ...")
    segment_captions = {}

    # Case-insensitive so ".PNG" segments are not silently ignored.
    seg_files = sorted(f for f in os.listdir(folder_path) if f.lower().endswith(".png"))
    for seg_file in seg_files:
        if seg_file in previous and seg_file not in failed:
            # Reuse captions that succeeded earlier; only recaption failures/new files.
            segment_captions[seg_file] = previous[seg_file]
            continue

        seg_path = os.path.join(folder_path, seg_file)
        print(f"    Segment: {seg_file}")
        try:
            caption = get_blip2_caption(seg_path)
        except Exception as e:  # best-effort: record the failure and keep going
            caption = f"{ERROR_PREFIX}{e}"
        print(f"      Caption: {caption}")
        segment_captions[seg_file] = caption

    if segment_captions:
        all_captions[folder_name] = segment_captions

        # Save after each folder so an interrupted run can resume from here.
        _save_captions(all_captions, output_json)

        print(f" Saved progress for folder {folder_name}")

print(f"\n All captions saved to {output_json}")
