In [2]:
import spacy
from spacy import displacy
import pandas as pd
import re

2024-09-19 12:58:28.054362: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-19 12:58:28.054443: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-19 12:58:28.193028: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-19 12:58:28.430924: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-09-19 12:58:32.461207: I external/local_xla/xla/

In [5]:
# Load spaCy's small English pipeline; later cells use it to parse prompts and meta captions.
pos_tag_pipe = spacy.load("en_core_web_sm")
# Append spaCy's built-in component that merges each named entity into a single token.
pos_tag_pipe.add_pipe("merge_entities")
# pos_tag_pipe.add_pipe("merge_noun_chunks")
# Caption table; later cells read its 'Category', 'Meta Caption' and 'Prompts' columns.
sdxl_captions = pd.read_csv('./meta_captions_sdxl.csv')

In [3]:
# Prompt categories kept for the target dataset.
target_categories = ['Colors', 'Positional', 'Counting', 'Descriptions']
# Read the prompt table and keep only rows whose 'Category' is one of the targets.
df = pd.read_csv('/notebooks/Fine-Grained-Hallucination/DrawBenchPrompts.csv').loc[
    lambda frame: frame['Category'].isin(target_categories)
]

In [6]:
# Renumber the filtered rows from 0. Without drop=True, reset_index() keeps the old index as a new column.
# NOTE(review): this rebinds `df`, so re-running the cell adds a further index column each time.
df = df.reset_index()

In [8]:
# Save the filtered prompts next to the notebook; index=None is presumably meant to omit the row index -- TODO confirm.
df.to_csv('target_prompt_dataset.csv', index=None)

In [18]:
# Count how many rows of the caption table fall into each 'Category' value.
sdxl_captions['Category'].value_counts()

Category
Reddit                 38
Colors                 25
Text                   21
DALL-E                 20
Descriptions           20
Positional             20
Counting               19
Conflicting            10
Gary Marcus et al.     10
Misspellings           10
Rare Words              7
Name: count, dtype: int64

In [6]:
class FineGrainedMetrics:
    """Recall-style checks of how much fine-grained prompt detail a meta caption keeps.

    Every public metric is a classmethod called as ``metric(meta, orig, generated_nouns)``:

    * ``meta`` / ``orig`` -- the parsed meta caption and the parsed prompt. They only
      need to be iterable over tokens exposing ``text``, ``dep_``, ``subtree``,
      ``children``, ``ancestors`` and ``head``, and to expose the raw string as ``.text``.
    * ``generated_nouns`` -- the noun strings whose attributes should be inspected.

    Each metric returns ``-1`` when the prompt contains nothing for it to score,
    otherwise a recall value between 0 and 1.
    """

    # Returned by every metric when the prompt gives it nothing to score.
    NOT_APPLICABLE = -1

    COLOURS = frozenset([
        'red', 'blue', 'green', 'yellow', 'black', 'white', 'gray', 'grey', 'orange',
        'pink', 'purple', 'brown', 'violet', 'indigo', 'turquoise', 'cyan', 'magenta'
    ])

    SHAPES = frozenset([
        'circular', 'square', 'triangular', 'rectangular', 'oval',
        'hexagonal', 'pentagonal', 'octagonal', 'spherical', 'cubical',
        'cylindrical', 'conical', 'pyramidal'
    ])

    # Determiners and quantity words rewritten to digit strings so that
    # "a dog", "the dog" and "one dog" all compare equal.
    QUANTITIES = {
        'a': '1', 'an': '1', 'the': '1',
        'one': '1', 'two': '2', 'three': '3', 'four': '4', 'five': '5', 'six': '6',
        'seven': '7', 'eight': '8', 'nine': '9', 'ten': '10', 'eleven': '11', 'twelve': '12',
        'couple': '2',
        'dozen': '12',
    }

    TEXT_INDICATORS = frozenset(['written', 'saying', 'says'])
    QUOTED_TEXT = r'["\']([^"\']*)["\']'

    _ADJECTIVE_DEPS = ('acomp', 'amod')
    _QUANTITY_DEPS = ('nummod', 'det')
    _SUBJECT_DEPS = ('nsubj', 'nsubjpass')

    @classmethod
    def related_to_noun(cls, doc, attr, noun):
        """Return True if ``attr`` is the text of a token in the subtree of any token equal to ``noun``.

        Every occurrence of the noun is inspected, because the attribute may not
        be attached to its first occurrence. Returns False when nothing matches.
        """
        return any(
            attr in (t.text for t in token.subtree)
            for token in doc
            if token.text == noun
        )

    @classmethod
    def _noun_attributes(cls, doc, noun, deps):
        """Lower-cased texts of tokens with a dependency in ``deps`` that relate to ``noun``."""
        return {
            token.text.lower()
            for token in doc
            # The cheap label test goes first so the subtree scan only runs for candidates.
            if token.dep_ in deps and cls.related_to_noun(doc, token.text, noun)
        }

    @classmethod
    def _recall(cls, matched, expected):
        """Turn hit / total counts into a recall, or the sentinel when there was nothing to recall."""
        if expected == 0:
            return cls.NOT_APPLICABLE
        return matched / expected

    @classmethod
    def _vocabulary_recall(cls, meta, orig, generated_nouns, vocabulary):
        """Recall of the prompt's (noun, adjective) pairs whose adjective is in ``vocabulary``.

        Hits and totals are both counted per noun, so two nouns sharing the same
        adjective count as two expected pairs and the result cannot exceed 1.
        """
        matched = 0
        expected = 0
        for noun in generated_nouns:
            orig_attrs = vocabulary & cls._noun_attributes(orig, noun, cls._ADJECTIVE_DEPS)
            if not orig_attrs:
                continue
            meta_attrs = vocabulary & cls._noun_attributes(meta, noun, cls._ADJECTIVE_DEPS)
            expected += len(orig_attrs)
            matched += len(orig_attrs & meta_attrs)
        return cls._recall(matched, expected)

    @classmethod
    def colour(cls, meta, orig, generated_nouns):
        """Recall of the colour adjectives the prompt attaches to each generated noun (-1 if none)."""
        return cls._vocabulary_recall(meta, orig, generated_nouns, cls.COLOURS)

    @classmethod
    def _noun_quantities(cls, doc, noun):
        """Numerals and determiners related to ``noun``, normalised through ``QUANTITIES``."""
        words = cls._noun_attributes(doc, noun, cls._QUANTITY_DEPS)
        return {cls.QUANTITIES.get(word, word) for word in words}

    @classmethod
    def number(cls, meta, orig, generated_nouns):
        """Recall of the quantities the prompt attaches to each generated noun (-1 if none).

        Numerals and determiners are normalised first, so "a", "an", "the" and
        "one" all count as '1'. Words missing from ``QUANTITIES`` are compared verbatim.
        """
        matched = 0
        expected = 0
        for noun in generated_nouns:
            orig_nums = cls._noun_quantities(orig, noun)
            if not orig_nums:
                continue
            meta_nums = cls._noun_quantities(meta, noun)
            expected += len(orig_nums)
            matched += len(orig_nums & meta_nums)
        return cls._recall(matched, expected)

    @classmethod
    def text(cls, meta, orig, _):
        """Recall of the quoted strings of the prompt that reappear quoted in the meta caption.

        Only scored when the prompt contains an indicator word ('written',
        'saying', 'says') and at least one quoted string; otherwise returns -1.
        Strings are compared lower-cased with all whitespace removed, and a
        prompt string counts once however many caption strings contain it.
        """
        if not any(token.text.lower() in cls.TEXT_INDICATORS for token in orig):
            return cls.NOT_APPLICABLE

        def squash(quoted):
            return ''.join(quoted.lower().split())

        orig_texts = [squash(s) for s in re.findall(cls.QUOTED_TEXT, orig.text)]
        if not orig_texts:
            # Indicator word without any quoted text: nothing to recall.
            return cls.NOT_APPLICABLE
        meta_texts = [squash(s) for s in re.findall(cls.QUOTED_TEXT, meta.text)]
        found = sum(
            1 for wanted in orig_texts
            if any(wanted in seen for seen in meta_texts)
        )
        return found / len(orig_texts)

    @classmethod
    def _subject_of(cls, prep):
        """Text of the subject a preposition token relates to, or None.

        First choice is the nearest ancestor that is itself a subject ("the cat
        on the table sleeps"). Failing that, the subject attached to the
        preposition's direct head is used ("the cat sits on the table").
        """
        for ancestor in prep.ancestors:
            if ancestor.dep_ in cls._SUBJECT_DEPS:
                return ancestor.text
        for sibling in prep.head.children:
            if sibling.dep_ in cls._SUBJECT_DEPS:
                return sibling.text
        return None

    @classmethod
    def extract_triplets(cls, doc):
        """Return ``(subject, preposition, object)`` text triplets found in ``doc``.

        A triplet is produced for every 'prep' token that has a 'pobj' child and
        for which a subject can be found; only the first 'pobj' child is used.
        """
        triplets = []
        for token in doc:
            if token.dep_ != 'prep':
                continue
            pobj = next((child for child in token.children if child.dep_ == 'pobj'), None)
            if pobj is None:
                continue
            subject = cls._subject_of(token)
            if subject is not None:
                triplets.append((subject, token.text, pobj.text))
        return triplets

    @classmethod
    def position(cls, meta, orig, generated_nouns):
        """Recall of the prompt's positional triplets that also occur in the meta caption.

        ``generated_nouns`` is accepted for a uniform metric signature but is not
        used. Returns -1 when the prompt yields no triplets.
        """
        orig_triplets = set(cls.extract_triplets(orig))
        if not orig_triplets:
            return cls.NOT_APPLICABLE
        meta_triplets = set(cls.extract_triplets(meta))
        return len(orig_triplets & meta_triplets) / len(orig_triplets)

    @classmethod
    def shape(cls, meta, orig, generated_nouns):
        """Recall of the shape adjectives the prompt attaches to each generated noun (-1 if none)."""
        return cls._vocabulary_recall(meta, orig, generated_nouns, cls.SHAPES)

In [None]:
def stage_one_metric(meta, orig):
    """Measure which nouns of the prompt reappear in the meta caption.

    Args:
        meta: parsed meta caption; iterable of tokens exposing ``text`` and ``pos_``.
        orig: parsed prompt; same token interface.

    Returns:
        A ``(noun_recall, non_generated, generated_nouns)`` tuple where
        ``generated_nouns`` are the prompt nouns also found in the caption,
        ``non_generated`` are the prompt nouns missing from it, and
        ``noun_recall`` is the fraction of prompt nouns that were found.
        ``noun_recall`` is ``-1`` when the prompt contains no nouns at all,
        instead of raising ``ZeroDivisionError``.
    """
    noun_tags = ("NOUN", "PROPN")
    meta_nouns = {token.text for token in meta if token.pos_ in noun_tags}
    orig_nouns = {token.text for token in orig if token.pos_ in noun_tags}

    generated_nouns = orig_nouns & meta_nouns
    non_generated = orig_nouns - meta_nouns
    # Recall is undefined without prompt nouns; use the same -1 sentinel as the fine-grained metrics.
    noun_recall = len(generated_nouns) / len(orig_nouns) if orig_nouns else -1
    return noun_recall, non_generated, generated_nouns

def stage_two_metric(meta, orig, generated_nouns, aspects=['colour', 'number', 'shape', 'position', 'text']):
    """Evaluate the requested fine-grained aspects and return them as ``{aspect: score}``.

    Each entry of ``aspects`` must be the name of a ``FineGrainedMetrics``
    attribute; it is looked up by name and called with
    ``(meta, orig, generated_nouns)``.
    """
    return {
        aspect: getattr(FineGrainedMetrics, aspect)(meta, orig, generated_nouns)
        for aspect in aspects
    }

def stage_three_metric():
    """Placeholder for a third evaluation stage; not implemented, always returns None."""
    return None

# Index label of the last row to score (inclusive); rows after it are skipped.
LAST_ROW_LABEL = 180

# One record per scored row, so every row's metrics survive the loop instead of
# being overwritten by the next iteration.
scored_rows = []
for i, row in sdxl_captions.iterrows():
    meta_caption = row['Meta Caption']
    orig_caption = row['Prompts']
    meta_ = pos_tag_pipe(meta_caption)
    orig_ = pos_tag_pipe(orig_caption)
    n_recall, non_generated, generated_nouns = stage_one_metric(meta_, orig_)
    fine_grained_metrics = stage_two_metric(meta_, orig_, generated_nouns, aspects=['colour', 'number', 'text', 'position', 'shape'])
    scored_rows.append({'row': i, 'noun_recall': n_recall, **fine_grained_metrics})
    if i == LAST_ROW_LABEL:
        break

# Per-row scores as a table; -1 marks metrics that did not apply to a prompt.
metric_table = pd.DataFrame(scored_rows)
    # print(n_recall, list(non_generated))



In [None]:
# Show the prompt bound by the last iteration of the scoring loop.
orig_caption

In [2]:
# Show the meta caption bound by the scoring loop; that cell must have run first, otherwise this name is undefined.
meta_caption

NameError: name 'meta_caption' is not defined

In [119]:
# Show the fine-grained metric dict computed for the last row the scoring loop processed.
fine_grained_metrics

{'colour': -1, 'number': -1, 'text': 0.0}

In [None]:
# Look up spaCy's description of the 'advcl' dependency label.
spacy.explain('advcl')