Based on the code from the original paper on DAAM (Tang et al., 2022).

[Github repository](https://github.com/castorini/daam/tree/main?tab=readme-ov-file)

# Setup

In [1]:
# Mount Google Drive at /content/drive so that the project folder stored on Drive
# is reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os

# Go to the DAAM folder in Google Drive so that the relative paths used by later
# cells (requirements.txt, the CSV files, `experiments/`, progress.txt) resolve there.
folder_path = '/content/drive/My Drive/Colab Notebooks/Stable diffusion/DAAM Wrong'
os.chdir(folder_path)
# Echo the resulting working directory as a sanity check.
print("Current working directory: ", os.getcwd())

Current working directory:  /content/drive/My Drive/Colab Notebooks/Stable diffusion/DAAM Wrong


In [None]:
# Use the %pip magic rather than a `!pip` shell call so that the packages listed in
# requirements.txt are installed into the environment of the running kernel.
%pip install -r requirements.txt

In [None]:
import torch

# Enter a CUDA autocast context by hand. No matching __exit__ is called in this
# cell, so the context is left open for the code that runs afterwards.
torch.cuda.amp.autocast().__enter__()
# Turn gradient tracking off globally (the notebook only runs inference).
# The trailing `;` suppresses the cell's output.
torch.set_grad_enabled(False);

In [None]:
import daam

In [None]:
from diffusers import StableDiffusionPipeline

In [None]:
from daam import set_seed, trace
# Load the pretrained Stable Diffusion pipeline identified by
# 'stabilityai/stable-diffusion-2-1-base'. `pipe` is used as a notebook-level
# global by the generation cell further down.
pipe = StableDiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2-1-base')
# Move the pipeline to the first CUDA device; the trailing `;` suppresses the cell's output.
pipe.to('cuda:0');

In [5]:
import pandas as pd
# Load the caption table from the current working directory. Later cells rely on
# the columns id, subject, object, verb, preposition, caption, wrong_preposition
# and wrong_caption being present.
df = pd.read_csv('final_coco_dataframe.csv')

In [6]:
# Replace the original "preposition"/"caption" columns with their "wrong" variants:
# drop the originals, then rename "wrong_preposition" -> "preposition" and
# "wrong_caption" -> "caption".
#
# The swap is guarded so that re-running this cell is a no-op. Without the guard a
# second run would drop the freshly renamed "preposition"/"caption" columns, and the
# rename (which ignores missing labels by default) would silently restore nothing.
if {'wrong_preposition', 'wrong_caption'}.issubset(df.columns):
    df = (
        df.drop(columns=['preposition', 'caption'])
          .rename(columns={'wrong_preposition': 'preposition', 'wrong_caption': 'caption'})
    )

# Take only the first 450 rows of the DataFrame
df = df.head(450)

print(df.head())

   image_id      id       subject          object     verb preposition  \
0    321333  770625     A picture  two young kids   posing       above   
1    366611  765550           dog     soccer ball   laying      out of   
2    490171  370135  a surf board       the water   riding      within   
3    491757  694991         A cat       blue eyes  sitting     outside   
4    247806  768641       A clock    a bell tower      NaN     towards   

                                             caption  
0  A picture above two young kids posing or a pic...  
1      a dog laying out of a soccer ball on a field.  
2  A man and a dog riding a surf board within the...  
3     A cat outside blue eyes sitting on a pink bed.  
4     A clock towards a bell tower of an old church.  


# Generate DAAM Maps

In [None]:
# Ask PyTorch to release cached GPU memory it is no longer using before the
# generation run starts.
torch.cuda.empty_cache()

In [None]:
import torch
from tqdm import tqdm
from pathlib import Path

def generate_images(df, start_idx=0, batch_size=2):
    """Generate one image per row of `df` under a DAAM trace and save each experiment.

    Rows are processed in consecutive batches of `batch_size`, starting at the
    positional index `start_idx`. For every row the generator is seeded with the
    row's `id`, the pipeline is run for 20 inference steps on the row's `caption`,
    and the traced result is saved under `experiments/` with `heat_maps=False`.
    After each batch, the positional index of the next unprocessed row is written
    to `progress.txt` so that an interrupted run can be resumed.

    A failure on a single row is reported and does not stop the run. The
    checkpoint still advances past failed rows, so they are NOT retried on resume;
    their ids are listed in a summary printed at the end so they can be re-run.

    Relies on the notebook-level names `pipe`, `set_seed` and `trace`.

    Args:
        df: DataFrame with at least the columns `id` and `caption`.
        start_idx: Positional row index to start from.
        batch_size: Number of rows processed between checkpoint writes (>= 1).

    Raises:
        ValueError: If `batch_size` is smaller than 1.
    """
    # Fail early with a clear message instead of a ZeroDivisionError (batch_size == 0)
    # or a silent no-op (batch_size < 0).
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    remaining = max(len(df) - start_idx, 0)
    total_batches = -(-remaining // batch_size)  # ceiling division
    output_folder = Path('experiments')
    output_folder.mkdir(exist_ok=True)
    failed_ids = []

    for batch_num in range(total_batches):
        start = start_idx + batch_num * batch_size
        end = min(start + batch_size, len(df))
        batch_df = df.iloc[start:end]

        for _, row in tqdm(batch_df.iterrows(), total=len(batch_df), desc=f'Batch {batch_num+1}/{total_batches}'):
            # Named `sample_id` rather than `id` to avoid shadowing the builtin.
            sample_id, caption = row['id'], row['caption']
            try:
                # Seed with the row id so each image is reproducible.
                gen = set_seed(sample_id)
                with torch.no_grad():
                    with trace(pipe) as tc:
                        pipe(caption, num_inference_steps=20, generator=gen)
                        exp = tc.to_experiment(output_folder, id=str(sample_id), seed=sample_id)
                        exp.save(output_folder, heat_maps=False)
                torch.cuda.empty_cache()
            except Exception as e:
                # Best effort: keep going, but remember the id for the final summary.
                failed_ids.append(sample_id)
                print(f"Error processing id {sample_id}: {e}")

        # Save progress after each batch
        with open('progress.txt', 'w') as f:
            f.write(str(end))

    if failed_ids:
        print(f"{len(failed_ids)} id(s) failed and were skipped: {failed_ids}")

# Check for existing progress and resume
try:
    with open('progress.txt', 'r') as f:
        start_index = int(f.read().strip())
except FileNotFoundError:
    # No checkpoint yet: this is the first run.
    start_index = 0
except ValueError:
    # The checkpoint exists but is empty or corrupt (e.g. the session died while
    # it was being rewritten). Start over rather than crash.
    print("progress.txt is empty or corrupt; restarting from index 0")
    start_index = 0

generate_images(df, start_idx=start_index)

In [None]:
# Location of the saved experiments
experiment_path = Path('experiments')

# Count only the sub-directories directly inside it; plain files are ignored
folder_count = len([entry for entry in experiment_path.iterdir() if entry.is_dir()])

print(f"There are {folder_count} folders in {experiment_path}")

There are 450 folders in experiments


# Parse and analyse

In [None]:
from matplotlib import pyplot as plt
from daam import GenerationExperiment

In [None]:
def iou(a, b, t: float = 0.1) -> float:
    """Intersection-over-union of two maps after thresholding both at `t`.

    Each input is binarised with `> t`. The result is the number of positions
    active in both masks divided by the number active in at least one.
    Returns 0.0 when neither mask has any active position.
    """
    mask_a = a > t
    mask_b = b > t

    union_area = (mask_a | mask_b).float().sum()
    if u_is_empty := bool(union_area < 1e-6):
        return 0.0

    overlap_area = (mask_a & mask_b).float().sum()
    return (overlap_area / union_area).item()

In [None]:
import pandas as pd
import gc
from pathlib import Path
from tqdm import tqdm
from daam import GenerationExperiment

print("DataFrame loaded with {} rows.".format(len(df)))

# Columns whose word-level heat maps are needed, and the (head, dependent) column
# pairs whose overlap is measured.
words_of_interest = ['preposition', 'subject', 'object', 'verb']
pairs_of_interest = [('preposition', 'subject'), ('preposition', 'object'), ('preposition', 'verb')]

stats = []

# One saved experiment per row, stored under experiments/<id>
for row in tqdm(df.itertuples(index=True), total=len(df)):
    experiment_path = Path('experiments') / str(row.id)
    if not experiment_path.exists():
        print(f"Experiment path does not exist: {experiment_path}")
        continue

    print(f"Loading experiment from: {experiment_path}")
    exp = GenerationExperiment.load(experiment_path)

    try:
        heat_map = exp.heat_map()

        # Heat map per column; missing values and words DAAM cannot map are skipped
        word_maps = {}
        for role in words_of_interest:
            phrase = getattr(row, role)
            if pd.isna(phrase):
                continue
            try:
                word_maps[role] = heat_map.compute_word_heat_map(phrase).value.cuda()
            except ValueError as err:
                print(f"Could not compute heat map for {role}: {phrase}, Error: {err}")

        # Record an IoU only for pairs where both maps are available
        for head_role, dep_role in pairs_of_interest:
            if head_role not in word_maps or dep_role not in word_maps:
                continue
            stats.append({
                'pair': f"{head_role}-{dep_role}",
                'preposition': row.preposition,
                'iou': iou(word_maps[head_role], word_maps[dep_role]),
                'experiment_path': str(experiment_path)
            })

    except Exception as exc:
        print(f"Error processing {str(experiment_path)}: {str(exc)}")

    # Drop the loaded experiment before moving on to keep memory usage flat
    del exp
    gc.collect()

stats_df = pd.DataFrame(stats)
print("Statistics collected:")
print(stats_df)

In [None]:
# Lower-case the prepositions so that differently-cased spellings fall into the
# same group in the aggregations below.
stats_df['preposition'] = stats_df['preposition'].str.lower()

In [None]:
# Persist the per-pair IoU records (without the index column); the aggregation
# section reloads this file instead of recomputing the heat maps.
stats_df.to_csv('stats_output.csv', index=False)

# Aggregate statistics

In [7]:
# Count how many rows of the original DataFrame use each preposition
# (case-insensitively: the column is lower-cased first)

df['preposition'] = df['preposition'].str.lower()
preposition_counts = df['preposition'].value_counts()

# Two-column table: one row per preposition, most frequent first
preposition_counts_df = (
    preposition_counts
    .rename_axis('preposition')
    .reset_index(name='count')
)

print(preposition_counts_df)

   preposition  count
0        above     14
1           at     13
2      against     13
3      outside     12
4           of     12
..         ...    ...
56          up      4
57      behind      4
58          to      3
59     next to      3
60     through      3

[61 rows x 2 columns]


In [8]:
# Spread of the per-preposition row counts, i.e. how evenly the prepositions are
# represented in the data. Uses pandas' default `Series.std()` (sample standard
# deviation, ddof=1).
overall_std = preposition_counts_df['count'].std()
print(f"Overall standard deviation: {overall_std:.2f}")

Overall standard deviation: 2.63


In [None]:
# Load the statistics DataFrame from the CSV written by the analysis section, so
# the aggregations below can run without recomputing the heat maps.
stats_df = pd.read_csv('stats_output.csv')

In [None]:
# Calculate the mean IoU across all pairs, ignoring prepositions
# (every row of stats_df, i.e. every recorded pair, carries equal weight).
mean_iou_all = stats_df['iou'].mean()
# Express the mean as a percentage rounded to 4 decimal places for display.
rounded_mean_iou_percentage = round(mean_iou_all * 100, 4)
print("Mean IoU across all pairs:", rounded_mean_iou_percentage, "%")

Mean IoU across all pairs: 3.9373 %


In [None]:
# Mean IoU per pair type (prepositions pooled), expressed as a percentage rounded
# to 4 decimals and ordered from the largest to the smallest value
mean_iou_by_pair = (
    stats_df.groupby('pair')
    .agg(mean_iou=('iou', 'mean'))
    .mul(100)
    .round(4)
    .sort_values('mean_iou', ascending=False)
)

# Print the DataFrame
print("Mean IoU by Pair (in %):")
print(mean_iou_by_pair)

Mean IoU by Pair (in %):
                     mean_iou
pair                         
preposition-subject    6.4334
preposition-verb       2.8330
preposition-object     2.2512


In [None]:
# Lift every pandas display limit so the printed table is not truncated
for option_name in ('display.max_rows', 'display.max_columns', 'display.width', 'display.max_colwidth'):
    pd.set_option(option_name, None)

# Mean IoU per preposition (all pair types pooled), as a percentage rounded to
# 4 decimals, joined with the per-preposition row counts, ordered from the largest
# to the smallest mean and cut to the 30 highest entries
mean_iou_by_preposition = (
    stats_df.groupby('preposition')
    .agg(mean_iou=('iou', 'mean'))
    .mul(100)
    .round(4)
    .merge(preposition_counts_df, on='preposition')
    .sort_values('mean_iou', ascending=False)
    .head(30)
)

# Print the DataFrame
print("Mean IoU by Preposition (in %), including counts:")
print(mean_iou_by_preposition)

Mean IoU by Preposition (in %), including counts:
   preposition  mean_iou  count
6       around   17.3033      6
49     towards   14.6391      7
2       across   14.4956      9
21      during   12.7049      7
19        down   11.7368      6
59        with   10.4236      8
14     between   10.3090      8
51  underneath    9.7202      4
45        over    9.0317      7
8           at    8.0500     13
38          on    6.9435     10
11      behind    6.5046      4
3      against    6.3417     13
32        into    6.2519      8
43     outside    5.7913     12
5        among    5.5326      9
29       in to    5.1457      8
20   down into    5.0200      5
26          in    4.7577      9
50       under    4.6821      9
0        about    4.4161      6
12      beside    4.3361      7
37         off    3.7499     11
27  in and out    3.5403      7
15      beyond    3.0795      4
7           as    2.7614      6
58       while    2.7308      8
40        onto    2.7251      8
34        near    2.65

In [None]:
# Adjust display settings to show full DataFrame without truncation
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_colwidth', None)

# Calculate the mean IoU for each type of pair and each preposition
mean_iou_by_pair_and_preposition = stats_df.groupby(['pair', 'preposition']).agg(
    mean_iou=('iou', 'mean')
).reset_index()

# Convert mean IoU to percentage
mean_iou_by_pair_and_preposition['mean_iou'] = (mean_iou_by_pair_and_preposition['mean_iou'] * 100).round(4)

# Merge with preposition_counts_df to include the count
mean_iou_by_pair_and_preposition = mean_iou_by_pair_and_preposition.merge(preposition_counts_df, on='preposition')

# Sort within each 'pair' group from largest to smallest by mean IoU
mean_iou_by_pair_and_preposition = mean_iou_by_pair_and_preposition.sort_values(['pair', 'mean_iou'], ascending=[True, False])

# Take the top 10 prepositions of each pair. `GroupBy.head` keeps every column
# (including 'pair') and the row order established by the sort above. The previous
# `groupby(...).apply(lambda x: x.head(10))` operated on the grouping column, which
# recent pandas versions deprecate.
top_10_per_pair = mean_iou_by_pair_and_preposition.groupby('pair').head(10).reset_index(drop=True)

# Print the DataFrame
print("Mean IoU by Pair and Preposition (in %), including preposition counts:")
print(top_10_per_pair)

Mean IoU by Pair and Preposition (in %), including preposition counts:
                   pair preposition  mean_iou  count
0    preposition-object     towards   17.2047      7
1    preposition-object        down   11.2391      6
2    preposition-object     against    8.6810     13
3    preposition-object       while    7.0220      8
4    preposition-object      around    6.2255      6
5    preposition-object         off    5.7830     11
6    preposition-object        into    5.4331      8
7    preposition-object      during    5.3276      7
8    preposition-object       under    4.6199      9
9    preposition-object      across    4.5877      9
10  preposition-subject        with   27.2618      8
11  preposition-subject      during   26.4347      7
12  preposition-subject        down   23.4674      6
13  preposition-subject      around   21.3427      6
14  preposition-subject  underneath   20.7593      4
15  preposition-subject      across   17.2964      9
16  preposition-subject     