In [None]:
import os
import sys

# Add project root to Python path for module imports (needed for unpickling StyleGAN2).
# The location can be overridden with the POSTHOC_CBM_ROOT environment variable so the
# notebook is not tied to one machine's absolute path; the default is unchanged.
project_root = os.environ.get('POSTHOC_CBM_ROOT', '/home/jz2003/ECS289H/posthoc-generative-cbm')
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Set the current working directory: the relative paths used by later cells
# (./config/..., models/checkpoints/..., CE_data/...) resolve against it.
os.chdir(project_root)
print(f"Current working directory: {os.getcwd()}")
from utils.utils import save_image_grid_with_labels, get_concept_index
import argparse
import numpy as np
from pathlib import Path
import yaml
import torch
from ast import literal_eval
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid
import torch.nn.functional as F
from torch import nn
from models import cbae_stygan2
from torchvision import transforms, models
from utils import gan_loss
import itertools
import warnings
import pickle
import matplotlib.pyplot as plt
from tqdm import tqdm
from eval.eval_intervention_gan import opt_int
import time

In [None]:
# --- Experiment selection -------------------------------------------------
dataset = 'celebahq'                    # dataset to evaluate on (also selects the YAML file)
expt_name = 'cbae_stygan2_thr90'        # experiment name (config sub-directory)
tensorboard_name = 'sup_pl_unk40_cls8'  # run tag that is part of the CB-AE checkpoint filename

# --- Optimization-based interventions ---------------------------------------
# (enable flag, step size, number of iterations)
optint, optint_eps, optint_iters = False, 0.1, 50

# --- Hardware ----------------------------------------------------------------
device = 'cuda:0'

In [None]:
# Each experiment keeps one YAML config per dataset under ./config/<expt_name>/.
config_file = f"./config/{expt_name}/{dataset}.yaml"

with open(config_file) as config_stream:
    config = yaml.safe_load(config_stream)
print(f"Loaded configuration file {config_file}")

In [None]:
# Human-readable (negative, positive) labels for each supervised concept, plus the
# matching CelebA attribute names. Position i in both lists refers to the same
# concept: the second entry of set_of_classes[i] is the positive value of
# conc_clsf_classes[i] (e.g. ['Female', 'Male'] <-> 'Male').
set_of_classes = [
    ['NOT Attractive', 'Attractive'],
    ['NO Lipstick', 'Wearing Lipstick'],
    ['Mouth Closed', 'Mouth Slightly Open'],
    ['NOT Smiling', 'Smiling'],
    ['Low Cheekbones', 'High Cheekbones'],
    ['NO Makeup', 'Heavy Makeup'],
    ['Female', 'Male'],
    ['Straight Eyebrows', 'Arched Eyebrows']
]
conc_clsf_classes = [
    'Attractive',
    'Wearing_Lipstick',
    'Mouth_Slightly_Open',
    'Smiling',
    'High_Cheekbones',
    'Heavy_Makeup',
    'Male',
    'Arched_Eyebrows',
]

# Derive the concept count from the list rather than hard-coding it, so these
# three definitions cannot silently drift apart when a concept is added/removed.
num_cls = len(conc_clsf_classes)
assert len(set_of_classes) == num_cls, (
    f"set_of_classes has {len(set_of_classes)} entries but conc_clsf_classes has {num_cls}"
)

# Point the config at the pretrained StyleGAN2 pickle (presumably read by
# cbAE_StyGAN2 when it builds the generator — confirm in models/cbae_stygan2.py).
config['model']['pretrained'] = 'models/checkpoints/stylegan2-celebahq-256x256.pkl'
# CB-AE weights are stored per dataset / experiment / run tag.
cbae_ckpt_path = f'models/checkpoints/{dataset}_{expt_name}_{tensorboard_name}_cbae.pt'

model = cbae_stygan2.cbAE_StyGAN2(config)
# map_location='cpu' lets the checkpoint load regardless of the device it was saved on;
# the whole model is moved to `device` afterwards.
model.cbae.load_state_dict(torch.load(cbae_ckpt_path, map_location='cpu'))
model.to(device)
model.eval();

In [None]:
# Set a fixed seed for reproducibility: the single latent sampled below (and hence
# every generated image) depends only on `seed`.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# Enumerate every binary assignment of the supervised concepts (2**8 = 256 for 8 concepts).
# itertools.product advances the LAST position fastest, so concept 0 is the most
# significant bit of a combination's index in `all_combinations`.
num_concepts = len(conc_clsf_classes)
all_combinations = list(itertools.product([0, 1], repeat=num_concepts))

print(f"Total combinations: {len(all_combinations)}")
print(f"Concepts: {conc_clsf_classes}")


def _swap_max_to(c_concepts, target):
    """Return a copy of `c_concepts` where each row's max value is swapped into column `target`.

    The value that was at `target` moves to the row's previous argmax column, so each
    row keeps the same values and only which column holds the maximum changes.
    (If `target` already holds the max, the row is unchanged.)
    """
    swapped = c_concepts.clone()
    rows = torch.arange(swapped.shape[0], device=swapped.device)
    old_vals = swapped[:, target].clone()
    max_val, max_ind = torch.max(swapped, dim=1)
    swapped[:, target] = max_val
    swapped[rows, max_ind] = old_vals
    return swapped


# Pure inference: no_grad avoids building an autograd graph for every one of the
# 2**num_concepts encoder/decoder/synthesis passes.
with torch.no_grad():
    # Generate a single latent vector from the seed
    z = torch.randn((1, model.gen.z_dim), device=device)
    latent = model.gen.mapping(z, None, truncation_psi=1.0, truncation_cutoff=None)

    # Get the original concepts for this latent
    concepts = model.cbae.enc(latent)

    # A concept's slice and its two "forced" versions (value 0 or 1) depend only on the
    # ORIGINAL concepts, so compute them once instead of once per combination.
    concept_slices = [get_concept_index(model, k) for k in range(num_concepts)]
    forced_slices = [
        [_swap_max_to(concepts[:, start:end], val) for val in (0, 1)]
        for start, end in concept_slices
    ]

    # Store all generated images
    all_images = []
    for combo in tqdm(all_combinations):
        new_concepts = concepts.clone()
        for (start, end), options, val in zip(concept_slices, forced_slices, combo):
            new_concepts[:, start:end] = options[val]

        # Decode and generate image
        new_latent = model.cbae.dec(new_concepts)
        gen_img = model.gen.synthesis(new_latent, noise_mode='const')
        gen_img = gen_img.mul(0.5).add_(0.5)  # maps an assumed [-1, 1] generator range to [0, 1]
        all_images.append(gen_img.cpu())

# Concatenate all images
all_images_tensor = torch.cat(all_images, dim=0)
print(f"Generated {len(all_images_tensor)} images with shape {all_images_tensor.shape}")


In [None]:
# Tile all combinations into one (roughly square) grid: 16x16 for 256 combinations.
# Clamping to [0, 1] instead of make_grid(normalize=True) matters: normalize=True rescales
# by the whole batch's min/max, which would make the displayed colours differ from the
# PNGs written later by save_image (which clamps to [0, 1]).
n_images = len(all_images_tensor)
grid_cols = int(np.ceil(np.sqrt(n_images)))
grid = make_grid(all_images_tensor.clamp(0, 1), nrow=grid_cols, padding=2)

# Display the grid
fig, ax = plt.subplots(1, 1, figsize=(20, 20))
ax.imshow(grid.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
ax.set_title(
    f'All {n_images} Combinations of {len(conc_clsf_classes)} Concepts (Seed={seed})',
    fontsize=16,
)
ax.axis('off')

plt.tight_layout()
plt.show()

print(f"\nConcepts (in order): {conc_clsf_classes}")
print(f"Total combinations shown: {len(all_combinations)}")


In [None]:
import textwrap

# Visualize a subset of interesting combinations with labels.
# Combinations are specified by the names of the concepts switched ON and then looked up
# in `all_combinations`, instead of hard-coding integer indices. itertools.product
# advances the LAST concept fastest, so concept 0 is the most significant bit of an index
# (e.g. "only Male" is index 2, not 64). The previous hard-coded indices did not match
# the combinations their comments described.


def _combo_with(*active_names):
    """Return the 0/1 tuple, ordered like `conc_clsf_classes`, with only `active_names` set to 1.

    Raises:
        ValueError: if a name is not in `conc_clsf_classes` (guards against typos).
    """
    unknown = set(active_names) - set(conc_clsf_classes)
    if unknown:
        raise ValueError(f"Unknown concept name(s): {sorted(unknown)}")
    return tuple(int(name in active_names) for name in conc_clsf_classes)


def _all_except(*excluded):
    """Return every name in `conc_clsf_classes` except those listed in `excluded`."""
    return [name for name in conc_clsf_classes if name not in excluded]


interesting_combos = [
    _combo_with(),                            # all 0: NOT Attractive, NO Lipstick, ..., Female, Straight Eyebrows
    _combo_with(*conc_clsf_classes),          # all 1: Attractive, Wearing Lipstick, ..., Male, Arched Eyebrows
    _combo_with('Male'),                      # only Male
    _combo_with(*_all_except('Attractive')),  # Male with everything except Attractive
    _combo_with(*_all_except('Male')),        # Female with everything else
    _combo_with('Smiling'),                   # only Smiling
    _combo_with('Male', 'Smiling'),           # only Male and Smiling
    _combo_with('High_Cheekbones'),           # only High Cheekbones
]
interesting_indices = [all_combinations.index(combo) for combo in interesting_combos]

# Get images for these indices
interesting_images = [all_images_tensor[i] for i in interesting_indices]

# Panel labels: the concepts set to 1, wrapped so long lists fit above a panel.
labels = []
for combo in interesting_combos:
    active = [name for name, val in zip(conc_clsf_classes, combo) if val]
    labels.append(textwrap.fill(', '.join(active) if active else '(none)', width=40))

# Create a grid with subplots (4 per row, as many rows as needed)
n_cols = 4
n_rows = -(-len(interesting_images) // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
axes = axes.flatten()

for ax, img, idx, combo, label in zip(axes, interesting_images, interesting_indices,
                                      interesting_combos, labels):
    # Clamp so out-of-range floats don't trigger imshow's clipping warning.
    ax.imshow(img.clamp(0, 1).permute(1, 2, 0).numpy())  # CHW -> HWC
    ax.set_title(f"Combo {idx} ({''.join(map(str, combo))})\n{label}", fontsize=10)
    ax.axis('off')
for ax in axes[len(interesting_images):]:
    ax.axis('off')  # hide unused panels

plt.suptitle(f'Selected Interesting Concept Combinations (Seed={seed})', fontsize=16)
plt.tight_layout()
plt.show()


In [None]:
import shutil

# Create output directory (relative to the current working directory)
output_dir = Path(f"CE_data/{expt_name}-{dataset}")
output_dir.mkdir(parents=True, exist_ok=True)

print(f"Saving images to {output_dir}")

# zip() would silently drop trailing items if these ever disagreed in length.
assert len(all_combinations) == len(all_images_tensor), (
    f"{len(all_combinations)} combinations but {len(all_images_tensor)} images"
)

# Save each image with its binary label as the filename
for combo, img in tqdm(zip(all_combinations, all_images_tensor),
                       total=len(all_combinations), desc="Saving images"):
    # Bit string in conc_clsf_classes order, e.g. (0,1,1,0,1,0,1,1) -> "01101011"
    binary_label = ''.join(map(str, combo))
    save_image(img, output_dir / f"{binary_label}.png")

print(f"Saved {len(all_images_tensor)} images")

# Copy the config file to the data directory.
# NOTE(review): this copies the on-disk YAML, so the in-memory override of
# config['model']['pretrained'] made before building the model is not recorded here.
config_dest = output_dir / Path(config_file).name
shutil.copy2(config_file, config_dest)
print(f"Copied config file to {config_dest}")
