In [4]:
import pandas as pd

# ===== Paths =====
base_path = "/content/drive/MyDrive/"
ref_csv = base_path + "DrawBenchPrompts.csv"
model_files = ["meta_captions_Flux-Dev.csv", "meta_captions_sd_2.csv", "meta_captions_sdxl.csv"]

# ===== Load reference =====
# The reference CSV supplies the full set (and row order) of image_name values
# that each per-model caption file is brought in line with below.
ref_df = pd.read_csv(ref_csv)

for model_file in model_files:
    model_path = base_path + model_file
    model_df = pd.read_csv(model_path)

    # image_names present in the reference but absent from this model's file.
    # Sorted so the printed list is stable between runs (set order is arbitrary).
    missing_images = sorted(set(ref_df["image_name"]) - set(model_df["image_name"]))
    print(f"{model_file}: {len(missing_images)} missing images")
    if missing_images:
        print("Missing image_names:", missing_images)

        # Build placeholder rows from the reference, with an empty caption.
        new_rows = ref_df[ref_df["image_name"].isin(missing_images)].copy()
        new_rows["Meta Caption"] = ""  # empty caption
        # Align to the model file's column layout. reindex (instead of
        # new_rows[model_df.columns]) fills columns the reference does not
        # have with NaN rather than raising KeyError.
        new_rows = new_rows.reindex(columns=model_df.columns)
        model_df = pd.concat([model_df, new_rows], ignore_index=True)

    # Reorder rows so they follow the reference's image_name order.
    model_df = model_df.set_index("image_name").reindex(ref_df["image_name"]).reset_index()

    # Save back over the original model file.
    model_df.to_csv(model_path, index=False)
    print(f"{model_file} updated with missing rows.\n")


meta_captions_Flux-Dev.csv: 29 missing images
Missing image_names: [8, 9, 168, 67, 70, 72, 73, 74, 75, 76, 77, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 94, 98, 99, 109]
meta_captions_Flux-Dev.csv updated with missing rows.

meta_captions_sd_2.csv: 0 missing images
meta_captions_sd_2.csv updated with missing rows.

meta_captions_sdxl.csv: 0 missing images
meta_captions_sdxl.csv updated with missing rows.



In [None]:
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

# ===============================
# 1. Load Phi-3-mini
# ===============================
model_name = "microsoft/phi-3-mini-4k-instruct"

print("🔹 Loading Phi-3-mini model...")

# Tokenizer and model are both loaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype="auto")

# Text-generation pipeline shared by the extraction code below.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# ===============================
# 2. Function to extract noun phrases
# ===============================
def get_entities(caption: str):
    """Return a list of noun phrases for a single caption.

    The caption is wrapped in a single user message asking for the noun
    phrases as a comma-separated list, sent through the module-level ``pipe``
    text-generation pipeline, and the generated text is split on commas.

    Args:
        caption: The sentence to analyse. Non-string values (e.g. ``None`` or
            a float NaN from a missing CSV cell) and blank strings are
            treated as "no caption".

    Returns:
        A list of stripped, non-empty phrase strings. Returns ``[]`` for a
        missing/blank caption or when generation raises an exception (the
        error is printed, not re-raised).
    """
    # Missing cells can arrive as None or float NaN rather than str; calling
    # .strip() on those would raise, so reject anything that is not real text.
    if not isinstance(caption, str) or not caption.strip():
        return []

    messages = [
        {"role": "user", "content": (
            "Extract all the noun phrases in the given sentence. "
            "Return them separated by commas, without rephrasing or extra text. "
            "Only keep phrases that contain a noun. "
            f"\nSentence: {caption}\nEntities:"
        )}
    ]
    generation_args = {
        "max_new_tokens": 50,
        "return_full_text": False,  # only the completion, not the prompt
        "temperature": 0.0,
        "do_sample": False
    }
    try:
        output = pipe(messages, **generation_args)
        text = output[0]['generated_text'].strip()
        # Split the comma-separated reply and drop empty fragments.
        entities = [ent.strip() for ent in text.split(",") if ent.strip()]
        return entities
    except Exception as e:
        # Best effort: one failing caption should not abort a whole file.
        print("Error processing caption:", caption, e)
        return []

# ===============================
# 3. Function to process a single CSV
# ===============================
def process_csv(file_path: str, caption_column: str):
    """Extract noun phrases for every caption in a CSV and save an augmented copy.

    Reads ``file_path``, runs ``get_entities`` on each value of
    ``caption_column`` through a thread pool, stores the results in a new
    column named ``f"{caption_column}_entities"``, and writes the frame to a
    sibling file with ``_entities`` appended to the base name.

    Args:
        file_path: Path of the input CSV.
        caption_column: Name of the column holding the captions.

    Returns:
        None. The result is written to disk and the output path is printed.
    """
    df = pd.read_csv(file_path)
    print(f"🔹 Processing {file_path} ({len(df)} rows)")

    # Empty CSV cells load as NaN, and str(NaN) is the literal "nan", which
    # would be sent to the model as if it were a caption. Map missing values
    # to "" so get_entities returns [] for them without calling the model.
    captions = ["" if pd.isna(value) else str(value) for value in df[caption_column]]

    # One slot per row, filled by row position, so completion order is
    # irrelevant and a failed row simply keeps its empty list.
    entities_by_row = [[] for _ in captions]

    # Parallel processing using ThreadPoolExecutor
    with ThreadPoolExecutor(max_workers=8) as executor:
        futures = {executor.submit(get_entities, caption): i for i, caption in enumerate(captions)}
        for future in tqdm(as_completed(futures), total=len(futures), desc="Extracting entities"):
            idx = futures[future]
            try:
                entities_by_row[idx] = future.result()
            except Exception as e:
                print("Error at row", idx, e)

    df[f"{caption_column}_entities"] = entities_by_row

    # Only rewrite a trailing ".csv". str.replace would also touch ".csv"
    # elsewhere in the path and, when the suffix is absent, would return the
    # input path unchanged and overwrite the source file.
    if file_path.endswith(".csv"):
        out_csv = file_path[:-len(".csv")] + "_entities.csv"
    else:
        out_csv = file_path + "_entities.csv"
    df.to_csv(out_csv, index=False)
    print(f"Saved entities to {out_csv}")

# ===============================
# 4. Run for all files
# ===============================
base_path = "/content/drive/MyDrive/"

# Label -> CSV path. The "Prompts" label marks the reference prompt file;
# the other labels are per-model caption files.
files_to_process = {
    label: base_path + filename
    for label, filename in [
        ("Prompts", "DrawBenchPrompts.csv"),
        ("Flux-Dev", "meta_captions_Flux-Dev.csv"),
        ("sd_2", "meta_captions_sd_2.csv"),
        ("sdxl", "meta_captions_sdxl.csv"),
    ]
}

for col_name, file_path in files_to_process.items():
    # The prompt file is read from its "Prompts" column; every other file
    # is read from "Meta Caption".
    if col_name == "Prompts":
        caption_col = "Prompts"
    else:
        caption_col = "Meta Caption"
    process_csv(file_path, caption_col)
