In [1]:
import spacy
from spacy import displacy
import pandas as pd
import re

In [2]:
# Load spaCy's small English pipeline; its POS tags and dependency labels drive the metrics below.
pos_tag_pipe = spacy.load("en_core_web_sm")
# Built-in spaCy component: retokenizes so each named-entity span becomes one token.
pos_tag_pipe.add_pipe("merge_entities")
# pos_tag_pipe.add_pipe("merge_noun_chunks")
# One caption per generated image (columns seen below: image_id, gemini_caption).
sdxl_captions = pd.read_csv('meta_captions_sdxl.csv')

In [3]:
# Preview the first five caption rows as loaded from the CSV.
sdxl_captions.head(5)

Unnamed: 0,image_id,gemini_caption
0,0.jpg,A classic red sports car parked against a gree...
1,1.jpg,A black car parked in front of a wall.
2,2.jpg,A pink classic car parked on a street.
3,3.jpg,A black puppy stands in a field.
4,4.jpg,A red dog resting its head.


In [4]:
# Create the empty Category column directly after gemini_caption.
# Guarded so the cell can be re-run: DataFrame.insert raises ValueError when the
# column already exists.
if "Category" not in sdxl_captions.columns:
    sdxl_captions.insert(
        loc=sdxl_captions.columns.get_loc("gemini_caption") + 1,
        column="Category",
        value=""
    )


In [5]:
# Confirm the (still empty) Category column sits right after gemini_caption.
sdxl_captions.head(5)

Unnamed: 0,image_id,gemini_caption,Category
0,0.jpg,A classic red sports car parked against a gree...,
1,1.jpg,A black car parked in front of a wall.,
2,2.jpg,A pink classic car parked on a street.,
3,3.jpg,A black puppy stands in a field.,
4,4.jpg,A red dog resting its head.,


In [6]:
# Load the prompt list (columns seen below: Prompts, Category).
df = pd.read_csv('DrawBenchPrompts.csv')
# reset_index() adds the row number as an explicit 'index' column; it is used later to build image ids.
df = df.reset_index()

In [7]:
# Preview the first five prompts together with their new 'index' column.
df.head(5)

Unnamed: 0,index,Prompts,Category
0,0,A red colored car.,Colors
1,1,A black colored car.,Colors
2,2,A pink colored car.,Colors
3,3,A black colored dog.,Colors
4,4,A red colored dog.,Colors


In [12]:
# Derive an image_id for every prompt from its row number ('0' -> '0.jpg') so the
# keys have the same format as sdxl_captions['image_id'].
df['image_id'] = df['index'].astype(str) + '.jpg'

# Lookup table: image_id -> Category
category_map = dict(zip(df['image_id'], df['Category']))

# Fill Category by lookup; image_ids missing from the map end up as NaN.
sdxl_captions['Category'] = sdxl_captions['image_id'].map(category_map)


In [13]:
# Inspect the captions now that Category has been filled in.
sdxl_captions

Unnamed: 0,image_id,gemini_caption,Category
0,0.jpg,A classic red sports car parked against a gree...,Colors
1,1.jpg,A black car parked in front of a wall.,Colors
2,2.jpg,A pink classic car parked on a street.,Colors
3,3.jpg,A black puppy stands in a field.,Colors
4,4.jpg,A red dog resting its head.,Colors
...,...,...,...
195,195.jpg,New York City fireworks.,Text
196,196.jpg,New Year's Eve fireworks over New York City.,Text
197,197.jpg,New York City fireworks.,Text
198,198.jpg,Google Doodle celebrating Canada Day.,Text


In [15]:
# Only these prompt categories are kept for the rest of the notebook.
target_categories = ['Colors', 'Positional', 'Counting', 'Descriptions']

# Keep only prompts in the target categories (this overwrites df with the filtered frame).
df = df.loc[df['Category'].isin(target_categories)]

In [16]:
# Apply the same category filter to the captions (overwrites sdxl_captions).
sdxl_captions = sdxl_captions.loc[sdxl_captions['Category'].isin(target_categories)]

In [17]:
# Persist the filtered prompts, including the derived 'index' and 'image_id' columns.
# NOTE(review): index=None is presumably meant as index=False (no row-label column) — confirm.
df.to_csv('target_prompt_dataset.csv', index=None)

In [18]:
# Number of remaining captions per category.
sdxl_captions['Category'].value_counts()

Unnamed: 0_level_0,count
Category,Unnamed: 1_level_1
Colors,25
Descriptions,20
Positional,20
Counting,19


In [20]:
# Preview the filtered prompts; image_id is the column derived from 'index'.
df.head()

Unnamed: 0,index,Prompts,Category,image_id
0,0,A red colored car.,Colors,0.jpg
1,1,A black colored car.,Colors,1.jpg
2,2,A pink colored car.,Colors,2.jpg
3,3,A black colored dog.,Colors,3.jpg
4,4,A red colored dog.,Colors,4.jpg


In [36]:
class FineGrainedMetrics:
    """Recall-style metrics comparing attributes of a caption against a reference prompt.

    All methods are classmethods operating on parsed documents: iterables of
    tokens exposing ``text``, ``dep_``, ``subtree``, ``children`` and
    ``ancestors`` (the document itself must expose ``text`` for :meth:`text`).
    ``orig`` is the reference whose attributes should be recalled and ``meta``
    is the document being scored.  Every metric returns ``-1`` when ``orig``
    contains nothing to recall, otherwise a recall in ``[0, 1]``.
    """

    # Dependency labels treated as adjectival modifiers of a noun.
    _ADJECTIVE_DEPS = ('acomp', 'amod')
    # Dependency labels treated as quantity markers of a noun.
    _QUANTITY_DEPS = ('nummod', 'det')

    _COLOURS = frozenset([
        'red', 'blue', 'green', 'yellow', 'black', 'white', 'gray', 'grey', 'orange',
        'pink', 'purple', 'brown', 'violet', 'indigo', 'turquoise', 'cyan', 'magenta'
    ])
    _SHAPES = frozenset([
        'circular', 'square', 'triangular', 'rectangular', 'oval',
        'hexagonal', 'pentagonal', 'octagonal', 'spherical', 'cubical',
        'cylindrical', 'conical', 'pyramidal'
    ])
    # Words mapped onto a digit string so that e.g. 'a' and 'the' compare equal.
    _QUANTITIES = {
        'a': '1',
        'an': '1',
        'the': '1',
        'couple': '2',
        'dozen': '12',
    }

    @classmethod
    def _subtree_texts(cls, doc, noun):
        """Return the texts found in the subtree of every token whose text equals ``noun``."""
        texts = set()
        for token in doc:
            if token.text == noun:
                texts.update(t.text for t in token.subtree)
        return texts

    @classmethod
    def related_to_noun(cls, doc, attr, noun):
        """Return True if ``attr`` occurs in the subtree of any token in ``doc`` whose text is ``noun``.

        All occurrences of ``noun`` are checked, because the attribute may be
        attached to a later occurrence rather than the first one.
        """
        return attr in cls._subtree_texts(doc, noun)

    @classmethod
    def _noun_attributes(cls, doc, noun, deps, vocabulary=None, normalise=None):
        """Collect the attribute words attached to ``noun`` in ``doc``.

        A token qualifies when its dependency label is in ``deps`` and its text
        occurs in the subtree of ``noun``.  ``normalise`` (if given) is applied
        to each text, and ``vocabulary`` (if given) restricts the result.
        """
        # Computed once per noun instead of re-scanning the doc for every token.
        related = cls._subtree_texts(doc, noun)
        attrs = {token.text for token in doc
                 if token.dep_ in deps and token.text in related}
        if normalise is not None:
            attrs = {normalise(a) for a in attrs}
        if vocabulary is not None:
            attrs &= vocabulary
        return attrs

    @classmethod
    def _attribute_recall(cls, meta, orig, generated_nouns, deps, vocabulary=None, normalise=None):
        """Recall of (noun, attribute) pairs from ``orig`` that also appear in ``meta``.

        Numerator and denominator both count per-noun pairs, so the result stays
        within [0, 1] even when several nouns share an attribute.  Returns -1
        when ``orig`` has no qualifying attribute on any of ``generated_nouns``.
        """
        matched = 0
        expected = 0
        for noun in generated_nouns:
            orig_attrs = cls._noun_attributes(orig, noun, deps, vocabulary, normalise)
            if not orig_attrs:
                continue
            meta_attrs = cls._noun_attributes(meta, noun, deps, vocabulary, normalise)
            expected += len(orig_attrs)
            matched += len(orig_attrs & meta_attrs)
        if expected == 0:
            return -1
        return matched / expected

    @classmethod
    def colour(cls, meta, orig, generated_nouns):
        """Recall of colour adjectives attached to ``generated_nouns`` (case-insensitive)."""
        return cls._attribute_recall(
            meta, orig, generated_nouns,
            deps=cls._ADJECTIVE_DEPS, vocabulary=cls._COLOURS, normalise=str.lower,
        )

    @classmethod
    def number(cls, meta, orig, generated_nouns):
        """Recall of quantity markers (numeric modifiers and determiners) per noun.

        Words listed in the quantity map are replaced by their digit string
        before comparison; all other words are compared lower-cased.
        """
        def normalise(word):
            lowered = word.lower()
            return cls._QUANTITIES.get(lowered, lowered)

        return cls._attribute_recall(
            meta, orig, generated_nouns,
            deps=cls._QUANTITY_DEPS, normalise=normalise,
        )

    @classmethod
    def text(cls, meta, orig, _):
        """Recall of quoted strings in ``orig`` that reappear inside a quoted string of ``meta``.

        Only evaluated when ``orig`` contains one of the indicator words
        ('written', 'saying', 'says').  Comparison ignores case and whitespace.
        Returns -1 when there is no indicator word or no quoted string in ``orig``.
        """
        indicators = ('written', 'saying', 'says')
        pattern = r'["\']([^"\']*)["\']'
        if not any(token.text in indicators for token in orig):
            return -1

        def squash(s):
            return ''.join(s.lower().split())

        orig_matches = [squash(m) for m in re.findall(pattern, orig.text)]
        if not orig_matches:
            # Indicator word without any quoted text: nothing to recall
            # (previously this divided by zero).
            return -1
        meta_matches = [squash(m) for m in re.findall(pattern, meta.text)]
        # Each reference string counts at most once, however many caption strings contain it.
        hits = sum(1 for wanted in orig_matches
                   if any(wanted in found for found in meta_matches))
        return hits / len(orig_matches)

    @classmethod
    def extract_triplets(cls, doc):
        """Return ``(subject, preposition, object)`` text triplets found in ``doc``.

        For every token labelled 'prep' that has a 'pobj' child, the subject is
        the first ancestor labelled 'nsubj' or 'nsubjpass'; prepositions without
        such an ancestor yield no triplet.
        """
        # NOTE(review): only ancestors of the preposition are searched, so a subject
        # that is a sibling of the preposition under a verb (or a verbless root
        # noun) presumably produces no triplet — confirm this is intended.
        triplets = []
        for token in doc:
            if token.dep_ != 'prep':
                continue
            pobj = next((child for child in token.children if child.dep_ == 'pobj'), None)
            if pobj is None:
                continue
            subject = next(
                (ancestor.text for ancestor in token.ancestors
                 if ancestor.dep_ in ('nsubj', 'nsubjpass')),
                None,
            )
            if subject is not None:
                triplets.append((subject, token.text, pobj.text))
        return triplets

    @classmethod
    def position(cls, meta, orig, generated_nouns):
        """Recall of the positional triplets of ``orig`` that also occur in ``meta``.

        ``generated_nouns`` is accepted for signature parity with the other
        metrics but is not used.  Returns -1 when ``orig`` yields no triplets.
        """
        orig_triplets = set(cls.extract_triplets(orig))
        if not orig_triplets:
            return -1
        meta_triplets = set(cls.extract_triplets(meta))
        return len(orig_triplets & meta_triplets) / len(orig_triplets)

    @classmethod
    def shape(cls, meta, orig, generated_nouns):
        """Recall of shape adjectives attached to ``generated_nouns`` (case-insensitive)."""
        return cls._attribute_recall(
            meta, orig, generated_nouns,
            deps=cls._ADJECTIVE_DEPS, vocabulary=cls._SHAPES, normalise=str.lower,
        )

In [28]:
# Merge the Prompts into sdxl_captions.
# Guarded so the cell is safe to re-run: merging a second time would clash on the
# existing 'Prompts' column and produce 'Prompts_x' / 'Prompts_y' instead.
if 'Prompts' not in sdxl_captions.columns:
    sdxl_captions = sdxl_captions.merge(
        df[['image_id', 'Prompts']],
        on='image_id',
        how='left'
    )


In [31]:
# Compare each generated caption with the prompt it was merged with.
sdxl_captions.head(17)

Unnamed: 0,image_id,gemini_caption,Category,Prompts
0,0.jpg,A classic red sports car parked against a gree...,Colors,A red colored car.
1,1.jpg,A black car parked in front of a wall.,Colors,A black colored car.
2,2.jpg,A pink classic car parked on a street.,Colors,A pink colored car.
3,3.jpg,A black puppy stands in a field.,Colors,A black colored dog.
4,4.jpg,A red dog resting its head.,Colors,A red colored dog.
5,5.jpg,A blue dog.,Colors,A blue colored dog.
6,6.jpg,Two bananas intertwined.,Colors,A green colored banana.
7,7.jpg,A single red banana.,Colors,A red colored banana.
8,8.jpg,A split banana.,Colors,A black colored banana.
9,9.jpg,A delicious-looking sandwich.,Colors,A white colored sandwich.


In [37]:
def stage_one_metric(meta, orig):
    """Noun-level recall of ``orig`` within ``meta``.

    Nouns are the exact (case-sensitive) texts of tokens tagged NOUN or PROPN.

    Returns a tuple ``(noun_recall, non_generated, generated_nouns)``:
    the fraction of ``orig`` nouns also present in ``meta`` (``-1`` when
    ``orig`` has no nouns), the ``orig`` nouns missing from ``meta``, and the
    ``orig`` nouns found in ``meta``.
    """
    def noun_texts(doc):
        return {tok.text for tok in doc if tok.pos_ in ("NOUN", "PROPN")}

    reference = noun_texts(orig)
    candidate = noun_texts(meta)

    generated_nouns = reference & candidate
    non_generated = reference - candidate

    if len(reference) > 0:
        noun_recall = len(generated_nouns) / len(reference)
    else:
        noun_recall = -1
    return noun_recall, non_generated, generated_nouns

def stage_two_metric(meta, orig, generated_nouns, aspects=['colour', 'number', 'shape', 'position', 'text']):
    """Evaluate the requested fine-grained aspects.

    Each name in ``aspects`` is looked up as an attribute of
    ``FineGrainedMetrics`` and called with ``(meta, orig, generated_nouns)``.
    Returns a dict mapping aspect name to that call's result, in the order given.
    """
    return {
        name: getattr(FineGrainedMetrics, name)(meta, orig, generated_nouns)
        for name in aspects
    }

def stage_three_metric():
    """Placeholder for a future third stage; currently does nothing and returns None."""
    return None

# Ensure Prompts are available in sdxl_captions
if 'Prompts' not in sdxl_captions.columns:
    sdxl_captions = sdxl_captions.merge(df[['image_id', 'Prompts']], on='image_id', how='left')

# Main loop: score every caption against its prompt and keep the per-image results
# (previously each iteration's metrics were overwritten by the next one).
metric_records = []
for i, row in sdxl_captions.iterrows():
    meta_caption = row['gemini_caption']
    orig_caption = row['Prompts']
    meta_ = pos_tag_pipe(meta_caption)
    orig_ = pos_tag_pipe(orig_caption)

    n_recall, non_generated, generated_nouns = stage_one_metric(meta_, orig_)
    fine_grained_metrics = stage_two_metric(
        meta_, orig_, generated_nouns,
        aspects=['colour', 'number', 'text', 'position', 'shape']
    )

    metric_records.append({
        'image_id': row['image_id'],
        'noun_recall': n_recall,
        **fine_grained_metrics,
    })

    # Stop once the row labelled 180 has been processed.
    if i == 180:
        break

# One row per scored image; -1 marks "metric not applicable".
# meta_caption, orig_caption and fine_grained_metrics still hold the values of the
# last processed row for ad-hoc inspection.
metric_results = pd.DataFrame(metric_records)



In [38]:
# Prompt of the last row the loop processed (variable left over from the loop).
orig_caption

'A zebra to the right of a fire hydrant.'

In [39]:
# Generated caption of the last row the loop processed (variable left over from the loop).
meta_caption

'Zebra drinking from a fire hydrant.'

In [40]:
# Stage-two metrics of the last row the loop processed; -1 means the prompt had nothing to score for that aspect.
fine_grained_metrics

{'colour': -1, 'number': 1.0, 'text': -1, 'position': -1, 'shape': -1}

In [41]:
# Look up spaCy's description of the 'advcl' dependency label.
spacy.explain('advcl')

'adverbial clause modifier'