Based on the code from the original paper on DAAM (Tang et al., 2022).

[Github repository](https://github.com/castorini/daam/tree/main?tab=readme-ov-file)

#Setup

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os

# Go to the DAAM folder in Google Drive
folder_path = '/content/drive/My Drive/Colab Notebooks/Stable diffusion/DAAM'
os.chdir(folder_path)
print("Current working directory: ", os.getcwd())

Current working directory:  /content/drive/My Drive/Colab Notebooks/Stable diffusion/DAAM


In [None]:
!pip install -r requirements.txt

In [None]:
import torch

torch.cuda.amp.autocast().__enter__()
torch.set_grad_enabled(False);

In [None]:
import daam

In [None]:
from diffusers import StableDiffusionPipeline

In [None]:
from daam import set_seed, trace
pipe = StableDiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2-1-base')
pipe.to('cuda:0');

In [5]:
import pandas as pd
df = pd.read_csv('final_coco_dataframe.csv')

#Generate DAAM Maps

In [None]:
torch.cuda.empty_cache()

In [None]:
import torch
from tqdm import tqdm
from pathlib import Path

def generate_images(df, start_idx=0, batch_size=2):
    total_batches = (len(df) - start_idx) // batch_size + ((len(df) - start_idx) % batch_size > 0)
    output_folder = Path('experiments')
    output_folder.mkdir(exist_ok=True)

    for batch_num in range(total_batches):
        start = start_idx + batch_num * batch_size
        end = min(start + batch_size, len(df))
        batch_df = df.iloc[start:end]

        for _, row in tqdm(batch_df.iterrows(), total=len(batch_df), desc=f'Batch {batch_num+1}/{total_batches}'):
            id, caption = row['id'], row['caption']
            try:
                gen = set_seed(id)
                with torch.no_grad():
                    with trace(pipe) as tc:
                        out = pipe(caption, num_inference_steps=20, generator=gen)
                        exp = tc.to_experiment(output_folder, id=str(id), seed=id)
                        exp.save(output_folder, heat_maps=False)
                torch.cuda.empty_cache()
            except Exception as e:
                print(f"Error processing id {id}: {e}")

        # Save progress after each batch
        with open('progress.txt', 'w') as f:
            f.write(str(end))

# Check for existing progress and resume
try:
    with open('progress.txt', 'r') as f:
        start_index = int(f.read().strip())
except FileNotFoundError:
    start_index = 0

generate_images(df, start_idx=start_index)

In [None]:
# Set the path to the 'experiments' directory
experiment_path = Path('experiments')

# Use a generator expression to count directories
folder_count = sum(1 for entry in experiment_path.iterdir() if entry.is_dir())

print(f"There are {folder_count} folders in {experiment_path}")

There are 870 folders in experiments


# Parse and analyse

In [None]:
from matplotlib import pyplot as plt
from daam import GenerationExperiment

In [None]:
def iou(a, b, t: float = 0.1) -> float:
    intersection = (a > t) & (b > t)
    union = (a > t) | (b > t)

    i = intersection.float().sum()
    u = union.float().sum()

    if u < 1e-6:
        return 0.0
    else:
        return (i / u).item()

In [None]:
import pandas as pd
import gc
from pathlib import Path
from tqdm import tqdm
from daam import GenerationExperiment

print("DataFrame loaded with {} rows.".format(len(df)))

stats = []

# Use itertuples for more efficient row iteration
for row in tqdm(df.itertuples(index=True), total=len(df)):
    experiment_path = Path('experiments') / str(row.id)
    if experiment_path.exists():
        print(f"Loading experiment from: {experiment_path}")
        exp = GenerationExperiment.load(experiment_path)
    else:
        print(f"Experiment path does not exist: {experiment_path}")
        continue

    try:
        heat_map = exp.heat_map()
        word_maps = {}
        words_of_interest = ['preposition', 'subject', 'object', 'verb']

        for word_type in words_of_interest:
            word = getattr(row, word_type)
            if pd.notna(word):
                try:
                    word_maps[word_type] = heat_map.compute_word_heat_map(word).value.cuda()
                except ValueError as ve:
                    print(f"Could not compute heat map for {word_type}: {word}, Error: {ve}")

        pairs_of_interest = [('preposition', 'subject'), ('preposition', 'object'), ('preposition', 'verb')]
        for head_type, dep_type in pairs_of_interest:
            if head_type in word_maps and dep_type in word_maps:
                iou_value = iou(word_maps[head_type], word_maps[dep_type])
                stats.append({
                    'pair': f"{head_type}-{dep_type}",
                    'preposition': getattr(row, 'preposition'),
                    'iou': iou_value,
                    'experiment_path': str(experiment_path)
                })

    except Exception as e:
        print(f"Error processing {str(experiment_path)}: {str(e)}")

    # Clear memory after each row is processed
    del exp
    gc.collect()

stats_df = pd.DataFrame(stats)
print("Statistics collected:")
print(stats_df)

In [None]:
stats_df['preposition'] = stats_df['preposition'].str.lower()

In [None]:
stats_df.to_csv('stats_output.csv', index=False)

#Aggregate statistics

In [6]:
# Calculate the number of prepositions in the original DataFrame

df['preposition'] = df['preposition'].str.lower()
preposition_counts = df['preposition'].value_counts()

preposition_counts_df = preposition_counts.reset_index()
preposition_counts_df.columns = ['preposition', 'count']

print(preposition_counts_df)

   preposition  count
0           on    176
1           in    147
2         with    131
3           of    112
4      next to     39
..         ...    ...
56       in to      1
57         out      1
58          as      1
59     towards      1
60      across      1

[61 rows x 2 columns]


In [8]:
overall_std = preposition_counts_df['count'].std()
print(f"Overall standard deviation: {overall_std:.2f}")

Overall standard deviation: 35.36


In [9]:
# Load the statistics DataFrame
stats_df = pd.read_csv('stats_output.csv')

In [None]:
# Calculate the mean IoU across all pairs, ignoring prepositions
mean_iou_all = stats_df['iou'].mean()
rounded_mean_iou_percentage = round(mean_iou_all * 100, 4)
print("Mean IoU across all pairs:", rounded_mean_iou_percentage, "%")

Mean IoU across all pairs: 6.088 %


In [11]:
# Calculate the mean IoU for each type of pair, ignoring prepositions
mean_iou_by_pair = stats_df.groupby('pair').agg(mean_iou=('iou', 'mean'))

# Convert to percentage
mean_iou_by_pair['mean_iou'] = (mean_iou_by_pair['mean_iou'] * 100).round(4)

# Sort from largest to smallest
mean_iou_by_pair = mean_iou_by_pair.sort_values('mean_iou', ascending=False)

# Print the DataFrame
print("Mean IoU by Pair (in %):")
print(mean_iou_by_pair)

Mean IoU by Pair (in %):
                     mean_iou
pair                         
preposition-subject    9.5919
preposition-verb       4.3271
preposition-object     3.9066


In [None]:
# Adjust display settings to show full DataFrame without truncation
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_colwidth', None)

# Calculate the mean IoU and count for each preposition
mean_iou_by_preposition = stats_df.groupby('preposition').agg(
    mean_iou=('iou', 'mean'),
    count=('iou', 'count')
)

# Convert mean IoU to percentage
mean_iou_by_preposition['mean_iou'] = (mean_iou_by_preposition['mean_iou'] * 100).round(4)

# Merge with preposition_counts_df to include the count
mean_iou_by_preposition = mean_iou_by_preposition.merge(preposition_counts_df, on='preposition', how='left')

# Sort from largest to smallest by mean IoU
mean_iou_by_preposition = mean_iou_by_preposition.sort_values('mean_iou', ascending=False)

# Display the first 30 entries
mean_iou_by_preposition = mean_iou_by_preposition.head(30)

# Print the DataFrame
print("Mean IoU by Preposition (in %), including counts:")
print(mean_iou_by_preposition)

Mean IoU by Preposition (in %), including counts:
   preposition  mean_iou  count_x  count_y
14     between   27.8215        3        1
1        above   25.1762        6        2
21      during   24.2776        3        1
45        over   19.2372       36       12
49     towards   17.8363        3        1
50       under   15.2187       17        6
59        with   14.4681      334      131
2       across   13.6069        3        1
11      behind   12.1306       22        8
47     through   11.5124       21        7
43     outside   11.3289       14        6
57     wearing   10.9702       12        5
5        among   10.2255        5        2
40        onto    9.4358        9        3
6       around    9.1239        9        3
32        into    8.2703       15        5
19        down    8.0870       75       25
16          by    6.7752       46       16
41         out    6.7213        3        1
23        from    6.3497       26        9
26          in    5.9013      409      147
8   

In [None]:
# Adjust display settings to show full DataFrame without truncation
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_colwidth', None)

# Calculate the mean IoU and count for each type of pair and each preposition
mean_iou_by_pair_and_preposition = stats_df.groupby(['pair', 'preposition']).agg(
    mean_iou=('iou', 'mean'),
    preposition_count=('iou', 'count')
)

# Convert mean IoU to percentage
mean_iou_by_pair_and_preposition['mean_iou'] = (mean_iou_by_pair_and_preposition['mean_iou'] * 100).round(4)

# Reset index to merge
mean_iou_by_pair_and_preposition = mean_iou_by_pair_and_preposition.reset_index()

# Merge with preposition_counts_df to include the count
mean_iou_by_pair_and_preposition = mean_iou_by_pair_and_preposition.merge(preposition_counts_df[['preposition', 'count']], on='preposition', how='left')

# Drop the original preposition_count column
mean_iou_by_pair_and_preposition = mean_iou_by_pair_and_preposition.drop(columns=['preposition_count'])

# Sort within each 'pair' group from largest to smallest by mean IoU
mean_iou_by_pair_and_preposition = mean_iou_by_pair_and_preposition.sort_values(['pair', 'mean_iou'], ascending=[True, False])

# Use groupby on 'pair' and apply a lambda to take the top 10 prepositions for each pair
# Reset the index to keep the 'pair' and 'preposition' columns
top_10_per_pair = mean_iou_by_pair_and_preposition.groupby('pair', group_keys=False).apply(lambda x: x.head(10)).reset_index(drop=True)

# Print the DataFrame
print("Mean IoU by Pair and Preposition (in %), including preposition counts:")
print(top_10_per_pair)

Mean IoU by Pair and Preposition (in %), including preposition counts:
                   pair preposition  mean_iou  count
0    preposition-object      during   30.8208      1
1    preposition-object     wearing   15.1107      5
2    preposition-object        onto   14.5926      3
3    preposition-object     outside   13.3888      6
4    preposition-object        with   11.4602    131
5    preposition-object      beside   10.9091      3
6    preposition-object       under   10.8165      6
7    preposition-object        over   10.2918     12
8    preposition-object          by    8.5471     16
9    preposition-object       above    8.3502      2
10  preposition-subject     between   83.4646      1
11  preposition-subject      across   36.4754      1
12  preposition-subject       above   34.7092      2
13  preposition-subject      behind   27.3161      8
14  preposition-subject        over   26.7186     12
15  preposition-subject       among   25.5637      2
16  preposition-subject     