In [None]:
%load_ext autoreload    
%autoreload 2   

In [None]:
import os
import glob
import json

from PIL import Image

ref_path = "./results/reference"
select_all = False

# Pre-load all reference image pairs, keyed "<edit_id>/<base_name>_img1" / "..._img2".
# Each file is opened in a `with` block and copied into memory so its file handle is
# released immediately. Keeping lazy `Image.open` objects would hold one open file
# descriptor per reference image (most of which may never be plotted), which can
# exhaust the OS file-descriptor limit on large reference sets.
ref_images = {}
for i in range(1, 10):  # sub-folders "1" .. "9"
    subfolder_path = os.path.join(ref_path, str(i))
    if not os.path.exists(subfolder_path):
        continue
    for img1_path in glob.glob(os.path.join(subfolder_path, "*_img1.png")):
        base_name = os.path.basename(img1_path).replace("_img1.png", "")
        img2_path = img1_path.replace("_img1.png", "_img2.png")
        # only complete pairs (both _img1 and _img2 present) are kept
        if os.path.exists(img2_path):
            for suffix, path in (("img1", img1_path), ("img2", img2_path)):
                with Image.open(path) as im:
                    ref_images[f"{i}/{base_name}_{suffix}"] = im.copy()

name_path = {}
for sae in glob.glob("./results/camera_ready_aggr/modesae*"):
    name = sae.split("/")[-1]
    ktrans = int(name.split("ktrans")[1].split("_")[0])
    strength = float(name.split("_str")[1])
    name_path[f"sae_{ktrans}_{strength}"] = sae

for steering in glob.glob("./results/camera_ready_aggr/modesteering*"):
    name = steering.split("/")[-1]
    strength = float(name.split("_str")[1])
    name_path[f"steer_{strength}"] = steering

for neurons in glob.glob("./results/camera_ready_aggr/modeneurons*"):
    name = neurons.split("/")[-1]
    ktrans = int(name.split("ktrans")[1].split("_")[0])
    strength = float(name.split("_str")[1])
    name_path[f"neur_{ktrans}_{strength}"] = neurons

# `selected` maps an edit-type id ("1".."9") to the list of example base names to plot.
if select_all:
    from collections import defaultdict

    # Every "*_img1.png" example found under the reference sub-folders.
    selected = defaultdict(list)
    for i in range(1, 10):
        subfolder_path = os.path.join(ref_path, str(i))
        if os.path.exists(subfolder_path):
            for ref in glob.glob(os.path.join(subfolder_path, "*_img1.png")):
                base_name = os.path.basename(ref).replace("_img1.png", "")
                selected[f"{i:d}"] += [base_name]
else:
    # Curated subset from disk; `with` closes the file handle (json.load(open(...)) leaked it).
    with open("riebench_selected_examples.json", "r") as selected_file:
        selected = json.load(selected_file)


# Human-readable names of the edit types, keyed by the ids "1".."9" (the same ids used
# as reference sub-folder names). Trailing notes are the authors' feasibility remarks.
edit2name = {
    str(edit_id): label
    for edit_id, label in enumerate(
        [
            "change object",  # possible
            "add object",  # hard -> hardcode: take the position from the edit prompt
            "delete object",  # possible -> use the text in brackets
            # possible -> take the edit mask in the original, boost the editing mask and
            # suppress the original mask (should apply to all edits as long as the two
            # versions of the image are similar to each other)
            "change content",
            "change pose",  # somewhat possible -> switch
            "change color",  # possible, "solved"
            "change material",  # possible, very hard
            "change background",  # possible -> always use "background" as the grounding prompt
            "change style",
        ],
        start=1,
    )
}

In [None]:
# Show every discovered result run: lookup key -> run directory path.
name_path

In [None]:
def _numeric_sort_key(value):
    """Sort key: numeric strings first (by numeric value), then other strings alphabetically.

    Never compares a float with a str (which raises TypeError in Python 3), and accepts
    negative and exponent forms such as "-1.0" or "1e-05".
    """
    try:
        return (0, float(value), value)
    except ValueError:
        return (1, 0.0, value)


def _resolve_name_path(name_path):
    """Return `name_path`, or the notebook-level `name_path` dict when it is None."""
    # Looked up at call time: a `name_path=name_path` default would be frozen at
    # definition time and go stale when the loading cell is re-run.
    return globals().get("name_path", {}) if name_path is None else name_path


def get_strenghts(prefix, name_path=None):
    """
    Returns a sorted list of unique strengths for a given prefix from the name_path dictionary.

    The strength is the last "_"-separated field of each key starting with `prefix`
    (e.g. "sae_10_1.0" -> "1.0"). Values are returned as strings, ordered numerically
    where they parse as floats. `name_path` defaults to the current global `name_path`.
    """
    keys = _resolve_name_path(name_path).keys()
    strengths = {name[len(prefix):].split("_")[-1] for name in keys if name.startswith(prefix)}
    return sorted(strengths, key=_numeric_sort_key)

def get_ktrans(prefix, name_path=None):
    """
    Returns a sorted list of unique ktrans values for a given prefix from the name_path dictionary.

    The ktrans value is the second-to-last "_"-separated field of each key starting with
    `prefix` (e.g. "sae_10_1.0" -> "10"). Values are returned as strings, ordered
    numerically where they parse as floats. `name_path` defaults to the current global
    `name_path`.
    """
    keys = _resolve_name_path(name_path).keys()
    ktrans_values = {name[len(prefix):].split("_")[-2] for name in keys if name.startswith(prefix)}
    return sorted(ktrans_values, key=_numeric_sort_key)

# Quick look at the parsed sweep values: SAE strengths, SAE ktrans values, steering strengths.
strenghts, ktrans = get_strenghts("sae"), get_ktrans("sae")
for values in (strenghts, ktrans, get_strenghts("steer")):
    print(values)

In [None]:
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

def abbrev_num(num):
    """Abbreviate values of 1000 or more as thousands, e.g. "2000" -> "2.0k".

    `num` may be a number or a numeric string. Values below 1000 are returned unchanged
    (same object, same type); larger ones become a string with one decimal and a "k" suffix.
    """
    value = float(num)
    return num if value < 1000 else f"{value / 1000:.1f}k"


# Each entry of `methods` is one figure row: the method keys shown between the Target
# (left) and Source (right) columns, with the matching steering run appended last.
# SAE rows come first; `seperator` is the index of the first neuron row.
NEUR_MAX_STRENGTH = 21  # neuron rows are only drawn for strengths strictly below this

sae_ktrans = get_ktrans("sae")
neur_ktrans = get_ktrans("neur")

methods = []
method_labels = []

col_labels_1 = [f"SAE {k}" for k in sae_ktrans] + ["Steering"]
for strength in get_strenghts("sae"):
    methods.append([f"sae_{k}_{strength}" for k in sae_ktrans] + [f"steer_{strength}"])
    method_labels.append(f"str. {strength}")

seperator = len(methods)
col_labels_2 = [f"Neur. {abbrev_num(k)}" for k in neur_ktrans] + ["Steering"]
for strength in get_strenghts("neur"):
    if float(strength) < NEUR_MAX_STRENGTH:
        methods.append([f"neur_{k}_{strength}" for k in neur_ktrans] + [f"steer_{strength}"])
        method_labels.append(f"str. {strength}")

# For a quick preview, slice `methods` and `method_labels` to their first few rows here.

# Width = Target + widest method row + Source. Using the widest row (rather than the
# last row built) keeps every column when the SAE and neuron sections differ in width
# and stays well-defined when there are no rows at all.
n_cols = max((len(row) for row in methods), default=0) + 2
n_rows = len(methods)
gap_size = 0.10  # height ratio (relative to 1 for image rows) of the blank spacer row before the neuron section

# Render one comparison grid per selected example.
# Rows follow `methods` (SAE section, optional spacer row, neuron section);
# columns are Target | one column per method in the row | Source.
OUT_DIR = "./results/supplementary_material/camera_ready/all"

# The GridSpec row layout depends only on `methods`, not on the example, so build it once.
# If a neuron section exists, a thin blank spacer row (height ratio `gap_size`) is
# inserted at GridSpec index `seperator`, right before the first neuron row.
has_gap = seperator is not None and seperator < n_rows
heights = [1] * n_rows
if has_gap:
    heights = heights[:seperator] + [gap_size] + heights[seperator:]
gs_nrows = len(heights)


def _true_row_idx(row_idx):
    """Map a method row index to its GridSpec row index (skipping the spacer row, if any)."""
    if has_gap and row_idx >= seperator:
        return row_idx + 1
    return row_idx


def _load_method_image(method, edit_id, name):
    """Return `<run_dir of method>/<edit_id>/<name>.png` as an array, or None on any error.

    The file is decoded inside a `with` block so its handle is closed immediately, and a
    missing run, missing file or decode failure only blanks this cell instead of
    aborting the whole figure.
    """
    try:
        with Image.open(os.path.join(name_path[method], edit_id, f"{name}.png")) as im:
            return np.array(im)
    except Exception as e:
        print(f"Error loading image for method {method}, edit_id {edit_id}, name {name}: {e}")
        return None


def _draw_grid(fig, edit_id, name, target, source):
    """Draw the Target | methods | Source grid for one example onto `fig`."""
    gs = gridspec.GridSpec(
        gs_nrows,
        n_cols,
        wspace=0,  # No white space between columns
        hspace=0,
        left=0,
        right=1,
        top=1,
        bottom=0,
        height_ratios=heights,
    )

    for row_idx, row_methods in enumerate(methods):
        between_images = [_load_method_image(m, edit_id, name) for m in row_methods]
        # pad shorter rows so every row spans n_cols - 2 method columns
        between_images += [None] * (n_cols - 2 - len(between_images))
        grid_row = _true_row_idx(row_idx)

        for col_idx in range(n_cols):  # Target + methods + Source
            cell = gs[grid_row, col_idx]
            ax = fig.add_subplot(cell)
            # Pin the axes exactly to its grid cell to eliminate 1-pixel whitespace between images
            ax.set_position(cell.get_position(fig))

            if col_idx == 0:
                ax.imshow(target)
                if row_idx == 0:
                    ax.set_title("Target", fontsize=20)
                ax.set_ylabel(method_labels[row_idx], fontsize=20)
            elif col_idx == n_cols - 1:
                ax.imshow(source)
                if row_idx == 0:
                    ax.set_title("Source", fontsize=20)
            else:
                between_idx = col_idx - 1
                image = between_images[between_idx]
                if image is None:
                    ax.text(
                        0.5, 0.5, "Image not available",
                        horizontalalignment="center", verticalalignment="center"
                    )
                else:
                    ax.imshow(image)
                    # Column headers: SAE labels on the first row, neuron labels on the first neuron row
                    if row_idx == 0:
                        ax.set_title(col_labels_1[between_idx], fontsize=20)
                    elif row_idx == seperator:
                        ax.set_title(col_labels_2[between_idx], fontsize=20)

            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_frame_on(False)
            for spine in ax.spines.values():
                spine.set_visible(False)
            if col_idx != 0:
                ax.set_ylabel('')

    # Fill the spacer row with invisible axes: it only adds vertical spacing, no line is drawn.
    if has_gap:
        for col_idx in range(n_cols):
            fig.add_subplot(gs[seperator, col_idx]).axis("off")

    fig.subplots_adjust(wspace=0, hspace=0, left=0, right=1, top=0.92, bottom=0)


for edit_id, names in tqdm(selected.items()):
    print("edit_id", edit_id)
    for name in names:
        stem = f"{OUT_DIR}/{edit_id}_{name}"
        if os.path.exists(f"{stem}_hq.pdf"):
            continue  # already rendered
        fig = None
        try:
            # img2 is shown as "Target" (left column), img1 as "Source" (right column);
            # converted once per example instead of once per row.
            target = np.array(ref_images[f"{edit_id}/{name}_img2"])
            source = np.array(ref_images[f"{edit_id}/{name}_img1"])
            fig = plt.figure(figsize=(4 * n_cols, 4 * n_rows))
            _draw_grid(fig, edit_id, name, target, source)
            os.makedirs(OUT_DIR, exist_ok=True)
            fig.savefig(f"{stem}_hq.pdf", dpi=300, bbox_inches='tight', pad_inches=0)
            fig.savefig(f"{stem}_lq.pdf", dpi=150, bbox_inches='tight', pad_inches=0)
            fig.savefig(f"{stem}.jpg", bbox_inches='tight', pad_inches=0)
        except Exception as e:
            # Best effort: report the failing example and continue with the next one.
            print(f"Error processing grid for edit_id {edit_id}, name {name}: {e}")
        finally:
            # Always release the figure we created (not just whichever figure is "current").
            if fig is not None:
                plt.close(fig)
