In [None]:
%load_ext autoreload

In [None]:
!nvidia-smi

In [None]:
%autoreload 2

import os
import gc
import time
import torch
import numpy
import pickle
import random
import threading
import numpy as np
import torchvision
import scienceplots
from numpy import dot
from tqdm import tqdm
from PIL import Image
import torch.nn as nn
import torch.optim as optim
from functools import partial
from numpy.linalg import norm
import matplotlib.pyplot as plt
from matplotlib import rcParams
import torch.nn.functional as F 
import matplotlib.colors as colors
from collections import defaultdict
from torch.utils.data import DataLoader
from torch.nn.functional import normalize
from torchvision import datasets, models, transforms

from lpips import LPIPS
from safetensors import safe_open
from diffusers import UNet2DConditionModel
from diffusers.utils import make_image_grid
from sklearn.metrics import pairwise_distances
from torchmetrics.functional.multimodal import clip_score
from torchmetrics.image.kid import KernelInceptionDistance
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline

In [None]:
# Publication-style figures: SciencePlots' "science" + "notebook" + "grid" styles,
# then font overrides (serif family, 15 pt text and axis labels, 12 pt tick labels).
plt.style.use(['science', 'notebook', 'grid'])
plt.rcParams.update({'font.family': 'serif', 'font.size': 15, 'axes.labelsize': 15})
plt.rcParams.update({'xtick.labelsize': 12, 'ytick.labelsize': 12})

In [None]:
%autoreload 2

# Hyperparameters
gpu_ids = [0, 1, 2] # List the indices of cuda devices
cuda_device='cuda:1' # Main cuda device, the others mentioned above are used for parallel image generation

# Concept prompt and the path to the unlearned ("ablated") weights: fill both in before running.
prompt = ""
finetuned_weights_path = ""

# Unlearning method that produced `finetuned_weights_path`; it selects the loading
# branch in `load_frozen_weights` and names the output directory out/{mul_method}/{prompt}.
# One of: "ablating", "sdd", "erasure", "safegen".
mul_method = ""

# True: evaluate_unlearning returns 1 - mean |gap| (retention); False: mean |gap| (forgetting).
retain_set_eval = False

original_weights_repo_hf = "CompVis/stable-diffusion-v1-4"

seed = 42
inference_steps = 100
guidance_scale = 7.5
eta = 1.0
device = torch.device(cuda_device)

# Seed every RNG the notebook draws from, so the classifier's train/val split
# (np.random.shuffle) and triplet sampling (random.choice / random.sample) are
# reproducible under Restart & Run All. Image generation reseeds per image anyway.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Hand-off points, as fractions of the denoising trajectory, passed as
# `denoising_end` (original model) / `denoising_start` (unlearned model):
# [0.001, 0.01, 0.05, 0.15, 0.25, 0.35, 0.45, 0.55]
eval_infer_timestep_chkpts = [0.001, 0.01] + [x / 100.0 for x in range(5, 60, 10)]
print(eval_infer_timestep_chkpts)

eval_set_size = 200  # images per ground-truth set (seeds 0 .. eval_set_size - 1)

# Resize (shorter side -> 256 px) + ImageNet mean/std normalisation, shared by the
# feature extractors and the fine-tuned classifier.
image_preprocessor = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
%autoreload 2

# Helper functions

def load_frozen_weights(model, save_path, mul_method="ablating") -> None:
    """Overwrite a Stable Diffusion pipeline's weights in place with unlearned weights.

    Args:
        model: pipeline exposing ``.unet`` (and ``.text_encoder`` for "ablating").
        save_path: checkpoint file, or a model directory for "sdd".
        mul_method: checkpoint format:
            - "ablating": ``torch.save``'d dict with a ``'unet'`` sub-dict of
              parameter tensors (only names present in it are copied) and an
              optional ``'text_encoder'`` state dict.
            - "erasure": ``torch.save``'d full UNet state dict.
            - "sdd": a ``UNet2DConditionModel.from_pretrained`` directory; the
              UNet object is replaced and moved to the old UNet's device.
            - "safegen": a safetensors file; keys matching the UNet state dict
              are copied in, others are ignored.

    Raises:
        ValueError: if ``mul_method`` is not one of the above. (Previously the
            call silently kept the original weights, so "ablated" outputs were
            really produced by the original model.)
    """
    if mul_method == "ablating":
        # Load on CPU; copy_/load_state_dict move the values onto the model's device.
        weights = torch.load(save_path, map_location="cpu")
        if 'text_encoder' in weights:
            model.text_encoder.load_state_dict(weights['text_encoder'])
        unet_weights = weights['unet']
        for name, params in model.unet.named_parameters():
            if name in unet_weights:
                params.data.copy_(unet_weights[name])
    elif mul_method == "erasure":
        weights = torch.load(save_path, map_location="cpu")
        model.unet.load_state_dict(weights)
    elif mul_method == "sdd":
        dev = model.unet.device
        del model.unet
        model.unet = UNet2DConditionModel.from_pretrained(save_path).to(dev)
    elif mul_method == "safegen":
        # Build the state dict once (not once per key). Its tensors share storage
        # with the live parameters/buffers, so copy_ updates the UNet in place.
        unet_state = model.unet.state_dict()
        with safe_open(save_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                if key in unet_state:
                    unet_state[key].copy_(f.get_tensor(key))
    else:
        raise ValueError(
            f"Unknown mul_method {mul_method!r}; expected 'ablating', 'erasure', 'sdd' or 'safegen'."
        )

def generate_diffusion_pipeline(weights_repo, cuda_device) -> StableDiffusionPipeline:
    """Load a text-to-image Stable Diffusion pipeline from ``weights_repo`` onto ``cuda_device``.

    The safety checker is disabled (``safety_checker=None``, ``requires_safety_checker=False``).
    """
    loaded = StableDiffusionPipeline.from_pretrained(
        weights_repo, safety_checker=None, requires_safety_checker=False
    )
    return loaded.to(cuda_device)

def generate_diffusion_img2img_pipeline(weights_repo, cuda_device) -> StableDiffusionImg2ImgPipeline:
    """Load an image-to-image Stable Diffusion pipeline from ``weights_repo`` onto ``cuda_device``.

    The safety checker is disabled (``safety_checker=None``, ``requires_safety_checker=False``).
    """
    loaded = StableDiffusionImg2ImgPipeline.from_pretrained(
        weights_repo, safety_checker=None, requires_safety_checker=False
    )
    return loaded.to(cuda_device)

def compress_pillow_image(img, output_path, quality=85, new_width=400):
    """Save a downscaled copy of ``img`` with its aspect ratio preserved.

    Args:
        img: PIL image (anything with ``.size``, ``.resize`` and ``.save``).
        output_path: destination file; the format follows from the extension.
        quality: forwarded to ``save`` (only lossy formats such as JPEG use it).
        new_width: target width in pixels (default 400, the previously fixed value).
    """
    width, height = img.size
    aspect_ratio = width / height
    # Clamp to 1 px so extremely wide images do not request a zero-height resize.
    new_height = max(1, int(new_width / aspect_ratio))
    resized = img.resize((new_width, new_height))
    resized.save(output_path, quality=quality)


def _grid_shape(num_images):
    """Return (rows, cols) with rows * cols == num_images, rows = largest divisor <= sqrt(num_images)."""
    import math

    rows = max(1, math.isqrt(num_images))
    while num_images % rows:
        rows -= 1
    return rows, num_images // rows


def print_image_grid(num_images, truth_type):
    """Tile the first ``num_images`` saved ground-truth images into one grid and show it.

    Reads ``./out/{mul_method}/{prompt}/{truth_type}_ground_truth_{i}.png`` for
    ``i in range(num_images)`` (``mul_method`` / ``prompt`` are the notebook
    globals). It writes the full-resolution grid to ``..._grid_{n}-full-size.png``
    and a 400-px-wide copy to ``..._grid_{n}.png``, then displays the grid inline.

    Square counts keep the ``sqrt x sqrt`` layout. Other counts, which previously
    failed inside ``make_image_grid``, use ``rows`` = the largest divisor not above
    the square root (e.g. 200 -> 10 x 20).

    Raises:
        ValueError: if ``num_images`` < 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be positive, got {num_images}")

    out_dir = f"./out/{mul_method}/{prompt}"
    grid_path = f"{out_dir}/{truth_type}_ground_truth_grid_{num_images}"

    images = []
    for num in range(num_images):
        # copy() loads the pixels so the file handle can be closed immediately.
        with Image.open(f"{out_dir}/{truth_type}_ground_truth_{num}.png") as image:
            images.append(image.copy())

    rows, cols = _grid_shape(num_images)
    grid = make_image_grid(images, rows=rows, cols=cols)
    grid.save(f"{grid_path}-full-size.png")
    display(grid)
    compress_pillow_image(grid, f"{grid_path}.png", quality=100)
    compress_pillow_image(grid, f"{grid_path}.png", quality=100)


def plot_ground_truth_samples(finetuned_pipeline, original_pipeline, prompt, mul_method, seed):
    """Show the unlearned model's and the original model's image for ``prompt`` side by side.

    Each image is cached at ``out/{mul_method}/{prompt}/{kind}_ground_truth_{seed}.png``
    (kind = "ablated" / "original"). If the file exists it is loaded. Otherwise the
    image is generated with ``seed`` and the notebook globals ``inference_steps`` /
    ``eta``, then saved.
    """
    panels = (
        (finetuned_pipeline, "ablated", "Ablated ground truth"),
        (original_pipeline, "original", "Original Domain Knowledge"),
    )
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 16))

    for ax, (pipe, kind, title) in zip(axes.flat, panels):
        cache_path = f"out/{mul_method}/{prompt}/{kind}_ground_truth_{seed}.png"
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)

        if os.path.exists(cache_path):
            picture = Image.open(cache_path)
        else:
            picture = pipe(
                prompt=prompt,
                num_inference_steps=inference_steps,
                output_type="pil",
                eta=eta,
                generator=torch.manual_seed(seed),
            ).images[0]
            picture.save(cache_path)

        ax.imshow(picture)
        ax.set_title(title)

    plt.show()

def _show_denoising_panel(ax, image, title):
    """Draw ``image`` on ``ax`` with ``title`` and no ticks or axis frame."""
    ax.imshow(image)
    ax.set_title(title)
    ax.axis('off')


def plot_denoised_iterations(original_pipeline, finetuned_i2i_pipeline, prompt, mul_method, eval_infer_timestep_chkpts):
    """Visualise the hand-off experiment: the original model starts denoising, the unlearned model finishes.

    For each fraction ``c`` in ``eval_infer_timestep_chkpts``:

    - Top row: the original pipeline runs with ``denoising_end=c`` and returns latents.
    - Bottom row: the unlearned img2img pipeline continues from those latents with
      ``denoising_start=c``.

    Both images are cached as ``out/{mul_method}/{prompt}/noised-{c}.png`` and
    ``denoised-{c}.png``. A checkpoint is regenerated unless BOTH files exist,
    because the denoised image needs the in-memory latents and cannot be rebuilt
    from the cached PNG. The figure is saved to
    ``out/{mul_method}/{prompt}/denoised_grid.png`` and shown.

    Uses the notebook globals ``inference_steps``, ``guidance_scale``, ``eta`` and ``seed``.
    """
    eval_image_name = f"{mul_method}/{prompt}"
    out_dir = f"out/{eval_image_name}"
    # Create the per-prompt directory itself, not just ./out, so the first save works on a fresh run.
    os.makedirs(out_dir, exist_ok=True)

    # squeeze=False keeps axes 2-D even with a single checkpoint.
    fig, axes = plt.subplots(nrows=2, ncols=len(eval_infer_timestep_chkpts), figsize=(20, 7), squeeze=False)
    fig.subplots_adjust(hspace=0.5)

    for i, eval_infer_chkpt in enumerate(eval_infer_timestep_chkpts):
        noised_path = f"{out_dir}/noised-{eval_infer_chkpt}.png"
        denoised_path = f"{out_dir}/denoised-{eval_infer_chkpt}.png"

        if os.path.exists(noised_path) and os.path.exists(denoised_path):
            noisy_image = Image.open(noised_path)
            denoised_image = Image.open(denoised_path)
        else:
            noisy_sample = original_pipeline(
                prompt=prompt,
                num_inference_steps=inference_steps,
                denoising_end=eval_infer_chkpt,
                output_type="latent",
                eta=eta,
                generator=torch.manual_seed(seed),
            ).images

            noisy_image = original_pipeline.image_processor.postprocess(
                noisy_sample, output_type="pil", do_denormalize=([True] * noisy_sample.shape[0])
            )[0]
            noisy_image.save(noised_path)

            denoised_image = finetuned_i2i_pipeline(
                prompt=prompt,
                num_inference_steps=inference_steps,
                guidance_scale=guidance_scale,
                eta=eta,
                generator=torch.manual_seed(seed),
                denoising_start=eval_infer_chkpt,
                image=noisy_sample,
            ).images[0]
            denoised_image.save(denoised_path)

        # Same titles and styling whether the images came from cache or were just generated.
        _show_denoising_panel(axes[0, i], noisy_image, f"Denoised {round(eval_infer_chkpt*100, 2)}% of T")
        _show_denoising_panel(axes[1, i], denoised_image, f"Denoised {round((1-eval_infer_chkpt)*100, 2)}% of T")

    fig.tight_layout()
    fig.savefig(f"{out_dir}/denoised_grid.png")
    plt.show()


def _split_seed_ranges(total_seeds, num_workers):
    """Split range(total_seeds) into num_workers contiguous (start, end) ranges covering every seed exactly once."""
    base, extra = divmod(total_seeds, num_workers)
    ranges = []
    start = 0
    for worker in range(num_workers):
        # The first `extra` workers take one additional seed each.
        end = start + base + (1 if worker < extra else 0)
        ranges.append((start, end))
        start = end
    return ranges


def generate_ground_truth_eval_set(gpus, set_count, gen_prompt, ground_truth_type, ablated=False):
    """Generate ``set_count`` images (seeds 0 .. set_count-1) in parallel, one thread per GPU.

    Each image is written to
    ``out/{mul_method}/{prompt}/{ground_truth_type}_ground_truth_{seed}.png`` and
    skipped if that file already exists. ``mul_method`` and ``prompt`` are notebook
    globals; ``gen_prompt`` is the text actually fed to the model. With
    ``ablated=True``, every per-GPU pipeline first gets the unlearned weights from
    ``finetuned_weights_path``.

    Every seed in ``range(set_count)`` is covered. The previous split produced one
    range more than there were GPUs whenever the count was not divisible, and that
    range was silently dropped.

    Raises:
        ValueError: if ``gpus`` is empty.
        RuntimeError: if any worker failed (chained from the first failure).
    """
    if not gpus:
        raise ValueError("gpus must list at least one CUDA device index")

    out_dir = f"out/{mul_method}/{prompt}"
    os.makedirs(out_dir, exist_ok=True)

    errors = []

    def generate_image_range(start_seed, end_seed, pipeline):
        try:
            for img_seed in range(start_seed, end_seed):
                output_path = f"{out_dir}/{ground_truth_type}_ground_truth_{img_seed}.png"
                if os.path.exists(output_path):
                    continue
                # A private generator per image: torch.manual_seed() reseeds the
                # process-wide RNG that the other worker threads are using concurrently.
                generator = torch.Generator(device="cpu").manual_seed(img_seed)
                image = pipeline(
                    prompt=gen_prompt,
                    num_inference_steps=inference_steps,
                    output_type="pil",
                    eta=eta,
                    generator=generator,
                ).images[0]
                image.save(output_path)
        except Exception as exc:  # plain threads would otherwise lose the error
            errors.append(exc)

    # Skip empty ranges so no pipeline is loaded onto a GPU that has no work.
    work = [
        (gpu_id, seed_range)
        for gpu_id, seed_range in zip(gpus, _split_seed_ranges(set_count, len(gpus)))
        if seed_range[1] > seed_range[0]
    ]

    pipelines = []
    for gpu_id, _ in work:
        pipe = generate_diffusion_pipeline(original_weights_repo_hf, f"cuda:{gpu_id}")
        if ablated:
            load_frozen_weights(pipe, finetuned_weights_path, mul_method)
        pipelines.append(pipe)

    threads = []
    for (_, (start_seed, end_seed)), pipeline in zip(work, pipelines):
        t = threading.Thread(target=generate_image_range, args=(start_seed, end_seed, pipeline))
        t.start()
        threads.append(t)

    for t in threads:
        t.join()

    if errors:
        raise RuntimeError(f"{len(errors)} image-generation worker(s) failed") from errors[0]

In [None]:
def evaluate_unlearning(sim_original, sim_ablated, retain_set=False):
    """Aggregate the per-checkpoint gap between the original- and ablated-reference score curves.

    Computes ``mean(|sim_original[i] - sim_ablated[i]|)``. It is returned as-is for
    a forget-set evaluation (``retain_set=False``); for a retain-set evaluation
    ``1 - mean`` is returned instead.

    Raises:
        AssertionError: if the two sequences differ in length.
        ValueError: if the sequences are empty (previously a ZeroDivisionError).
    """
    assert len(sim_original) == len(sim_ablated)
    if len(sim_original) == 0:
        raise ValueError("evaluate_unlearning needs at least one pair of scores")

    differences = [abs(orig - abl) for orig, abl in zip(sim_original, sim_ablated)]
    average_diff = sum(differences) / len(differences)

    return (1 - average_diff) if retain_set else average_diff

def calculate_scores(expr_name, x_axis, y_axis, prompt, mul_method, func, eval_set_size, feature_extractor, model, cuda_device):
    """Compute the two per-checkpoint score curves that ``plot_grid`` draws.

    Dispatch on ``func``:
      * ``calculate_kid_score``: KID of each denoised image against the original
        and against the ablated ground-truth set.
      * ``get_cosine_similarity_batched``: mean score of each denoised image
        against every original / ablated ground-truth image. Also prints the CRS
        from ``evaluate_unlearning``.
      * ``get_denoised_img_pred_scores``: classifier probability that each
        denoised image is "original", plus its complement. Also prints the CCS
        (mean probability x 100).
      * anything else: ``func(feature_extractor, image_preprocessor, cuda_device,
        denoised, reference)`` against the single ground-truth pair for the
        notebook-global ``seed``; ``.item()`` is taken on each result.

    ``expr_name``, ``x_axis`` and ``y_axis`` are accepted for interface
    compatibility but unused. Returns ``(score1_list, score2_list)``.
    """
    if func == calculate_kid_score:
        return (func(cuda_device, "original", eval_set_size),
                func(cuda_device, "ablated", eval_set_size))

    if func == get_cosine_similarity_batched:
        base = f"./out/{mul_method}/{prompt}"
        original_means, ablated_means = [], []
        for chkpt in tqdm(eval_infer_timestep_chkpts, desc="Calculating"):
            denoised_paths = [f"{base}/denoised-{chkpt}.png"] * eval_set_size
            original_paths = [f"{base}/original_ground_truth_{s}.png" for s in range(eval_set_size)]
            ablated_paths = [f"{base}/ablated_ground_truth_{s}.png" for s in range(eval_set_size)]

            original_means.append(np.mean(func(feature_extractor, image_preprocessor, cuda_device, denoised_paths, original_paths)))
            ablated_means.append(np.mean(func(feature_extractor, image_preprocessor, cuda_device, denoised_paths, ablated_paths)))

        print(f"CRS {func.__name__}: {evaluate_unlearning(original_means, ablated_means, retain_set=retain_set_eval)}")
        return original_means, ablated_means

    if func == get_denoised_img_pred_scores:
        base = f"./out/{mul_method}/{prompt}"
        probs = [
            get_denoised_img_pred_scores(f"{base}/denoised-{chkpt}.png", model, image_preprocessor, cuda_device)
            for chkpt in eval_infer_timestep_chkpts
        ]
        print(f"CCS: {np.mean(probs) * 100}")
        return probs, [1 - p for p in probs]

    # Generic pairwise scorer; `seed` is the notebook-global seed, not a local.
    rel = f"out/{mul_method}/{prompt}"
    original_ref = f"{rel}/original_ground_truth_{seed}.png"
    ablated_ref = f"{rel}/ablated_ground_truth_{seed}.png"
    to_original, to_ablated = [], []
    for chkpt in eval_infer_timestep_chkpts:
        denoised = f"{rel}/denoised-{chkpt}.png"
        to_original.append(func(feature_extractor, image_preprocessor, cuda_device, denoised, original_ref).item())
        to_ablated.append(func(feature_extractor, image_preprocessor, cuda_device, denoised, ablated_ref).item())
    return to_original, to_ablated


def _draw_score_curves(ax, score1_list, score2_list, expr_name, x_axis, y_axis):
    """Plot the original / unlearned score curves against ``eval_infer_timestep_chkpts`` on ``ax``."""
    if len(score1_list) > 0:
        if len(score2_list) == 0:
            ax.plot(eval_infer_timestep_chkpts, score1_list, 'o--')
        else:
            ax.plot(eval_infer_timestep_chkpts, score1_list, 'o--', label='Original Domain Knowledge')
    if len(score2_list) > 0:
        ax.plot(eval_infer_timestep_chkpts, score2_list, 'o--', label='Unlearned Domain Knowledge')

    ax.set_xlabel(x_axis)
    ax.set_ylabel(y_axis)
    ax.set_title(expr_name)
    # Only add a legend when something is labelled (avoids matplotlib's empty-legend warning).
    if ax.get_legend_handles_labels()[0]:
        ax.legend()


def plot_grid(prompt, mul_method, cuda_device, eval_set_size, model_name, plot_cases):
    """Score and plot every case in ``plot_cases`` as a grid with two columns.

    Each case is ``(expr_name, x_axis, y_axis, func, model, feature_extractor)`` and
    is scored with ``calculate_scores``. Scoring still runs concurrently (one worker
    per case), but all matplotlib calls now run on the calling thread. pyplot is not
    thread-safe, and exceptions in bare ``threading.Thread`` workers used to be
    swallowed, leaving blank panels in the saved PDFs; they are now re-raised here.

    Side effects: writes ``{func.__name__}_{model_name}.pdf`` per case and the
    combined ``{model_name}-scores.pdf`` into ``./out/{mul_method}/{prompt}/``,
    then shows the combined figure.

    Raises:
        ValueError: if ``plot_cases`` is empty.
    """
    from concurrent.futures import ThreadPoolExecutor

    if not plot_cases:
        raise ValueError("plot_cases must contain at least one case")

    output_dir = f"./out/{mul_method}/{prompt}/"
    os.makedirs(output_dir, exist_ok=True)

    def score_case(case):
        expr_name, x_axis, y_axis, func, model, feature_extractor = case
        return calculate_scores(expr_name, x_axis, y_axis, prompt, mul_method, func, eval_set_size, feature_extractor, model, cuda_device)

    with ThreadPoolExecutor(max_workers=len(plot_cases)) as pool:
        # list() re-raises the first worker exception on this thread.
        results = list(pool.map(score_case, plot_cases))

    num_cases = len(plot_cases)
    nrows = (num_cases + 1) // 2
    ncols = 2 if num_cases > 1 else 1
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 6, nrows * 6), squeeze=False)
    flat_axes = axes.flatten()

    for ax, case, (score1_list, score2_list) in zip(flat_axes, plot_cases, results):
        expr_name, x_axis, y_axis, func, _, _ = case
        _draw_score_curves(ax, score1_list, score2_list, expr_name, x_axis, y_axis)

        # Same panel again as its own PDF.
        fig_subplot, ax_subplot = plt.subplots(figsize=(6, 6))
        _draw_score_curves(ax_subplot, score1_list, score2_list, expr_name, x_axis, y_axis)
        fig_subplot.savefig(os.path.join(output_dir, f"{func.__name__}_{model_name}.pdf"))
        plt.close(fig_subplot)

    # Hide the spare axis left over when the number of cases is odd.
    for ax in flat_axes[num_cases:]:
        ax.set_visible(False)

    fig.subplots_adjust(wspace=0.3, hspace=0.3)
    fig.tight_layout()

    fig.savefig(f"./out/{mul_method}/{prompt}/{model_name}-scores.pdf")
    plt.show()

def calculate_kid_score(cuda_device, ground_truth_type, ground_truth_set_size=200, device=None):
    """KID between each denoised checkpoint image and a ground-truth image set.

    The real set is ``out/{mul_method}/{prompt}/{ground_truth_type}_ground_truth_{j}.png``
    for ``j < ground_truth_set_size``. For every checkpoint in the notebook-global
    ``eval_infer_timestep_chkpts``, the fake set is ``denoised-{chkpt}.png`` repeated
    ``ground_truth_set_size`` times, and the KID mean is recorded.

    Corrections to the earlier version:
      * Images are now fed as uint8 in [0, 255], which ``KernelInceptionDistance``
        expects with its default ``normalize=False``. Before, ImageNet-normalised
        tensors (which contain negative values) were multiplied by 255 and cast to
        uint8, which wraps and garbles pixels.
      * Fake features are now reset after each checkpoint. Before, they accumulated,
        so each score also mixed in every earlier checkpoint, and the real set was
        re-added on every iteration.

    Args:
        cuda_device: torch device string for the Inception network and tensors.
        ground_truth_type: file-name prefix of the real set ("original" / "ablated").
        ground_truth_set_size: number of real images; also the KID ``subset_size``.
        device: ignored, kept for backward compatibility (``cuda_device`` is used).

    Returns:
        list[float]: one KID mean per checkpoint.
    """
    device = torch.device(cuda_device)

    # Inception input size; PILToTensor yields uint8 CHW in [0, 255] (no normalisation).
    preprocess = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.PILToTensor(),
    ])

    # reset_real_features=False: the real set is embedded once and survives reset().
    kid_scorer = KernelInceptionDistance(subset_size=ground_truth_set_size, reset_real_features=False).to(device)

    base = f"out/{mul_method}/{prompt}"

    ground_truth_images = []
    for j in tqdm(range(ground_truth_set_size), desc='Loading ground truth images'):
        with Image.open(f"{base}/{ground_truth_type}_ground_truth_{j}.png") as image:
            ground_truth_images.append(preprocess(image.convert('RGB')))
    kid_scorer.update(torch.stack(ground_truth_images).to(device), real=True)
    del ground_truth_images

    kid_scores = []
    for chkpt in tqdm(eval_infer_timestep_chkpts, desc='Processing checkpoints'):
        with Image.open(f"{base}/denoised-{chkpt}.png") as image:
            denoised_image = preprocess(image.convert('RGB'))
        denoised_images = denoised_image.repeat(ground_truth_set_size, 1, 1, 1).to(device)

        kid_scorer.update(denoised_images, real=False)
        kid_score, _ = kid_scorer.compute()
        kid_scores.append(kid_score.item())
        # Clear this checkpoint's fake features so the next score is independent.
        kid_scorer.reset()

        del denoised_images
        gc.collect()
    return kid_scores



def get_cosine_similarity(feature_extractor, image_preprocessor, device, src1, src2):
    """Score two image files by the cosine similarity of their feature vectors.

    Steps:
      1. Load both images as RGB, as the other loaders in this notebook do, so an
         RGBA or greyscale PNG no longer breaks the 3-channel ``Normalize``.
      2. Preprocess and embed them with ``feature_extractor``, which is switched to
         eval mode and moved to ``device``.
      3. Squeeze each embedding to 1-D and L2-normalise it.
      4. Map the cosine ``c`` through ``(1 - arctan(c)) / (pi / 2)``.

    NOTE(review): that mapping decreases as ``c`` increases (identical images give
    about 0.137), so despite the name it behaves like a distance. Confirm this is
    intended before interpreting the plots.

    Returns:
        numpy.ndarray: 0-d array holding the score (callers take ``.item()``).
    """
    def _load(path):
        with Image.open(path) as img:
            return image_preprocessor(img.convert('RGB')).unsqueeze(0).to(device)

    image1 = _load(src1)
    image2 = _load(src2)

    feature_extractor.eval()
    feature_extractor.to(device)

    with torch.no_grad():
        # Squeeze away batch/spatial singleton dims, then L2-normalise.
        feature_vector_1 = nn.functional.normalize(feature_extractor(image1).squeeze(), dim=0)
        feature_vector_2 = nn.functional.normalize(feature_extractor(image2).squeeze(), dim=0)

        # Explicit 1-D inner product instead of `a @ b.T` on 1-D tensors.
        cos_sim = torch.dot(feature_vector_1, feature_vector_2)
        sim = (1 - torch.arctan(cos_sim)) / (np.pi / 2)

    return sim.detach().cpu().numpy()

def get_cosine_similarity_batched(feature_extractor, image_preprocessor, device, denoised_paths, target_paths):
    """Mean paired score between ``denoised_paths[i]`` and ``target_paths[i]``.

    Each pair is embedded with ``feature_extractor`` (eval mode, on ``device``),
    flattened to (N, D), L2-normalised, and its cosine ``c`` mapped through
    ``(1 - arctan(c)) / (pi / 2)``. The mean over all pairs is returned as a float.

    Only the paired (diagonal) similarities are computed, rather than a full N x N
    matrix. A path that repeats (``calculate_scores`` passes one denoised image N
    times) is read from disk only once.

    Raises:
        ValueError: if the path lists are empty or differ in length.
    """
    if len(denoised_paths) != len(target_paths):
        raise ValueError(f"Got {len(denoised_paths)} denoised paths but {len(target_paths)} target paths")
    if not denoised_paths:
        raise ValueError("At least one image pair is required")

    def _load_batch(paths):
        cache = {}
        tensors = []
        for path in paths:
            if path not in cache:
                with Image.open(path) as img:
                    cache[path] = image_preprocessor(img.convert('RGB'))
            tensors.append(cache[path])
        return torch.stack(tensors).to(device)

    denoised_batch = _load_batch(denoised_paths)
    target_batch = _load_batch(target_paths)

    feature_extractor.eval()
    feature_extractor.to(device)

    with torch.no_grad():
        # flatten(1) turns (N, C, 1, 1) into (N, C) and leaves (N, D) unchanged.
        denoised_features = F.normalize(feature_extractor(denoised_batch).flatten(1), dim=1)
        target_features = F.normalize(feature_extractor(target_batch).flatten(1), dim=1)

        # Row-wise dot product == diagonal of denoised @ target.T.
        cos_sim = (denoised_features * target_features).sum(dim=1)
        sim = (1 - torch.arctan(cos_sim)) / (np.pi / 2)

        sim_mean = sim.mean().item()

    return sim_mean


def get_denoised_img_pred_scores(denoised_path, model, image_preprocessor, cuda_device):
    """Return the classifier's probability that the image at ``denoised_path`` is an "original" sample.

    ``model`` is expected to be the two-class classifier from
    ``finetune_model_with_contrastive_learning``, where label 1 is the original
    ground truth and label 0 is ablated. It is put in eval mode but NOT moved;
    it is assumed to already be on ``cuda_device``.

    The image is converted to RGB first, matching the ``convert('RGB')`` used by
    the classifier's training and validation datasets.
    """
    device = torch.device(cuda_device)
    model.eval()

    with Image.open(denoised_path) as img:
        denoised_image = image_preprocessor(img.convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(denoised_image)
        prob_original = F.softmax(logits, dim=1)[:, 1]  # column 1 = "original ground truth"

    return prob_original.item()

def finetune_model_with_contrastive_learning(prompt, mul_method, eval_set_size, cuda_device, image_preprocessor, model_name):
    """Fine-tune an ImageNet-pretrained CNN to tell original from ablated ground-truth images.

    Data: label 0 for ``./out/{mul_method}/{prompt}/ablated_ground_truth_{i}.png``
    and label 1 for ``original_ground_truth_{i}.png`` (``i < eval_set_size``),
    randomly split 80/20 into train / validation.

    Model: ``model_name`` ('densenet', 'efficientnet' or 'resnet') with its head
    replaced by a 2-way linear layer.

    Training: 10 epochs of SGD (lr=0.001, momentum=0.9) on the sum of
      * a triplet loss over the pooled backbone features, and
      * cross-entropy from the 2-way head applied to the anchor features.

    Side effects:
      * prints per-epoch losses and validation accuracy;
      * saves and shows
        ``out/{mul_method}/{prompt}/{eval_set_size}-training_curves-{model_name}.pdf``;
      * writes the state dict to
        ``out/{mul_method}/{prompt}/{eval_set_size}-finetuned_model-{model_name}.pth``.

    Returns None.

    Raises:
        ValueError: for an unknown ``model_name``.

    NOTE(review): the split (np.random.shuffle) and triplet sampling
    (random.choice / random.sample) use the global RNGs, so runs are only
    reproducible if those are seeded beforehand.
    """
    # Define Triplet Dataset class with hard negative mining
    class TripletDataset(torch.utils.data.Dataset):
        # Yields (anchor, positive, negatives, anchor_label). The positive shares the
        # anchor's label; negatives are up to `num_negatives` random images of the other
        # label, padded by repetition. Selecting the hardest (closest) negative happens
        # in TripletLoss below.
        def __init__(self, image_paths, labels, transform, num_negatives=3):
            self.image_paths = image_paths
            self.labels = labels
            self.transform = transform
            self.num_negatives = num_negatives
            # label -> dataset indices carrying that label, used for positive/negative sampling
            self.label_to_indices = defaultdict(list)
            for idx, label in enumerate(self.labels):
                self.label_to_indices[label].append(idx)

        def __getitem__(self, index):
            anchor_path = self.image_paths[index]
            anchor_label = self.labels[index]
            anchor_image = Image.open(anchor_path).convert('RGB')
            anchor_image = self.transform(anchor_image)

            # Select positive sample from same class, avoiding identical images
            positive_indices = [i for i in self.label_to_indices[anchor_label] if i != index]
            if not positive_indices:
                # Anchor is the only image of its class in this split: reuse it as its own positive.
                positive_index = index
            else:
                positive_index = random.choice(positive_indices)
            
            positive_path = self.image_paths[positive_index]
            positive_image = Image.open(positive_path).convert('RGB')
            positive_image = self.transform(positive_image)

            # Select negative samples
            negative_label = 1 - anchor_label  # Assuming binary labels
            negative_indices = random.sample(self.label_to_indices[negative_label], 
                                             min(self.num_negatives, len(self.label_to_indices[negative_label])))
            
            negative_images = []
            for neg_idx in negative_indices:
                negative_path = self.image_paths[neg_idx]
                negative_image = Image.open(negative_path).convert('RGB')
                negative_images.append(self.transform(negative_image))

            # If we don't have enough negative samples, duplicate the last one
            # NOTE(review): if the split holds no image of the other label, negative_images
            # is empty here and negative_images[-1] raises IndexError.
            while len(negative_images) < self.num_negatives:
                negative_images.append(negative_images[-1])

            # Negatives are stacked into one tensor of num_negatives images.
            return anchor_image, positive_image, torch.stack(negative_images), anchor_label

        def __len__(self):
            return len(self.image_paths)

    # Standard dataset remains the same
    # Plain (image, label) pairs; used for validation.
    class StandardDataset(torch.utils.data.Dataset):
        def __init__(self, image_paths, labels, transform):
            self.image_paths = image_paths
            self.labels = labels
            self.transform = transform

        def __getitem__(self, index):
            image_path = self.image_paths[index]
            image = Image.open(image_path).convert('RGB')
            image = self.transform(image)
            label = self.labels[index]
            return image, label

        def __len__(self):
            return len(self.image_paths)

    # Create datasets
    # Label 0 = ablated (unlearned model) ground truth, 1 = original model ground truth.
    ablated_paths = [f"./out/{mul_method}/{prompt}/ablated_ground_truth_{seed}.png" for seed in range(eval_set_size)]
    original_paths = [f"./out/{mul_method}/{prompt}/original_ground_truth_{seed}.png" for seed in range(eval_set_size)]

    ablated_labels = [0] * eval_set_size
    original_labels = [1] * eval_set_size

    image_paths = ablated_paths + original_paths
    labels = ablated_labels + original_labels

    # Split indices for train and validation
    # First 20% of a shuffled index list -> validation, the rest -> training.
    dataset_size = len(image_paths)
    indices = list(range(dataset_size))
    split = int(np.floor(0.2 * dataset_size))
    np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]

    # Create datasets
    train_dataset = TripletDataset([image_paths[i] for i in train_indices], 
                                   [labels[i] for i in train_indices], 
                                   image_preprocessor)
    val_dataset = StandardDataset([image_paths[i] for i in val_indices], 
                                  [labels[i] for i in val_indices], 
                                  image_preprocessor)

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    # Initialize model and move to device
    # ImageNet-pretrained backbone with its final linear layer replaced by a 2-way head.
    device = torch.device(cuda_device)
    if model_name == 'densenet':
        model = models.densenet121(weights='IMAGENET1K_V1')
        num_ftrs = model.classifier.in_features
        model.classifier = nn.Linear(num_ftrs, 2)
    elif model_name == 'efficientnet':
        model = models.efficientnet_b0(weights='IMAGENET1K_V1')
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, 2)
    elif model_name == 'resnet':
        model = models.resnet18(weights='IMAGENET1K_V1')
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, 2)
    else:
        raise ValueError("Model name must be 'densenet', 'efficientnet', or 'resnet'.")

    model = model.to(device)

    # Define a function to extract features
    # Returns pooled, flattened backbone features, i.e. the input to each model's head,
    # so the triplet loss and the classification head can share one forward pass.
    # NOTE(review): intended to mirror each torchvision model's forward() up to the head; confirm per version.
    def extract_features(model, x):
        if model_name == 'densenet':
            features = model.features(x)
            out = nn.functional.relu(features, inplace=True)
            out = nn.functional.adaptive_avg_pool2d(out, (1, 1))
            out = torch.flatten(out, 1)
        elif model_name == 'efficientnet':
            features = model.features(x)
            out = nn.functional.adaptive_avg_pool2d(features, (1, 1))
            out = torch.flatten(out, 1)
        elif model_name == 'resnet':
            x = model.conv1(x)
            x = model.bn1(x)
            x = model.relu(x)
            x = model.maxpool(x)
    
            x = model.layer1(x)
            x = model.layer2(x)
            x = model.layer3(x)
            x = model.layer4(x)
    
            x = model.avgpool(x)
            x = torch.flatten(x, 1)
            out = x
        else:
            raise ValueError("Model name must be 'densenet', 'efficientnet', or 'resnet'.")
        return out
    
    # Define improved triplet loss
    # Margin loss using only the closest ("hardest") of the sampled negatives per anchor.
    class TripletLoss(nn.Module):
        def __init__(self, margin=1.0):
            super(TripletLoss, self).__init__()
            self.margin = margin
        
        def forward(self, anchor, positive, negatives):
            # anchor / positive: (B, D); negatives: (B, K, D), as built in the training loop.
            positive_dist = torch.norm(anchor - positive, dim=1)
            negative_dists = torch.norm(anchor.unsqueeze(1) - negatives, dim=2)
            hardest_negative_dist, _ = torch.min(negative_dists, dim=1)
            loss = torch.clamp(positive_dist - hardest_negative_dist + self.margin, min=0.0)
            return loss.mean()

    # Define loss functions and optimizer
    triplet_criterion = TripletLoss(margin=1.0)
    classification_criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Training loop
    num_epochs = 10
    train_total_losses = []
    train_triplet_losses = []
    train_classification_losses = []
    val_accuracies = []

    for epoch in range(num_epochs):
        running_loss = 0.0
        running_triplet_loss = 0.0
        running_classification_loss = 0.0

        model.train()
        # Note: `labels` here is the batch of anchor labels and shadows the list built above.
        for anchor, positive, negatives, labels in train_loader:
            anchor, positive, negatives, labels = anchor.to(device), positive.to(device), negatives.to(device), labels.to(device)
            
            optimizer.zero_grad()

            # Extract features
            anchor_features = extract_features(model, anchor)
            positive_features = extract_features(model, positive)
            
            # Process negative samples
            # negatives: (B, K, C, H, W) -> flattened to (B*K, C, H, W) for one backbone
            # pass, then the features are regrouped to (B, K, D).
            batch_size, num_negatives, C, H, W = negatives.size()
            negatives_reshaped = negatives.view(-1, C, H, W)
            negative_features = extract_features(model, negatives_reshaped)
            negative_features = negative_features.view(batch_size, num_negatives, -1)

            # Compute losses
            triplet_loss = triplet_criterion(anchor_features, positive_features, negative_features)
            # The classification head is applied to the anchor features only.
            if model_name == 'densenet' or model_name == 'efficientnet':
                classification_outputs = model.classifier(anchor_features)
            elif model_name == 'resnet':
                classification_outputs = model.fc(anchor_features)
            classification_loss = classification_criterion(classification_outputs, labels)
            total_loss = triplet_loss + classification_loss

            # Backward pass
            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()
            running_triplet_loss += triplet_loss.item()
            running_classification_loss += classification_loss.item()

        # Calculate epoch losses
        epoch_loss = running_loss / len(train_loader)
        epoch_triplet_loss = running_triplet_loss / len(train_loader)
        epoch_classification_loss = running_classification_loss / len(train_loader)

        train_total_losses.append(epoch_loss)
        train_triplet_losses.append(epoch_triplet_loss)
        train_classification_losses.append(epoch_classification_loss)

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Total Loss: {epoch_loss:.4f}, Triplet Loss: {epoch_triplet_loss:.4f}, Classification Loss: {epoch_classification_loss:.4f}")

        # Validation
        # Accuracy of the full model forward (backbone + head) on the held-out split.
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # NOTE(review): ZeroDivisionError if the validation split is empty (very small eval_set_size).
        val_accuracy = 100 * correct / total
        val_accuracies.append(val_accuracy)
        print(f"Validation Accuracy: {val_accuracy:.2f}%")

    # Save plots
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.plot(range(1, num_epochs+1), train_total_losses, 'o--', label='Total Loss')
    plt.plot(range(1, num_epochs+1), train_triplet_losses, 'o--', label='Triplet Loss')
    plt.plot(range(1, num_epochs+1), train_classification_losses, 'o--', label='Classification Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Losses')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(range(1, num_epochs+1), val_accuracies, 'o--', label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Validation Accuracy')
    plt.legend()

    plt.tight_layout()
    os.makedirs(f"out/{mul_method}/{prompt}/", exist_ok=True)
    plt.savefig(f"out/{mul_method}/{prompt}/{eval_set_size}-training_curves-{model_name}.pdf")
    plt.show()
    
    # Save model state dict without modification
    torch.save(model.state_dict(), f"out/{mul_method}/{prompt}/{eval_set_size}-finetuned_model-{model_name}.pth")

def _build_binary_classifier(model_name):
    """Instantiate an untrained backbone whose final linear layer is replaced by a 2-way head."""
    if model_name == 'densenet':
        model = models.densenet121(pretrained=False)
        model.classifier = nn.Linear(model.classifier.in_features, 2)
    elif model_name == 'efficientnet':
        model = models.efficientnet_b0(pretrained=False)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, 2)
    elif model_name == 'resnet':
        model = models.resnet18(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, 2)
    else:
        raise ValueError("Model name must be 'densenet', 'efficientnet', or 'resnet'.")
    return model


def _build_feature_extractor(model, model_name):
    """Wrap `model`'s backbone (shared modules, not copies) in an nn.Sequential ending in Flatten(1)."""
    if model_name == 'densenet':
        # ReLU + global average pool + flatten on top of the dense-block output.
        return nn.Sequential(
            model.features,
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(1)
        )
    if model_name == 'efficientnet':
        return nn.Sequential(
            model.features,
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(1)
        )
    if model_name == 'resnet':
        return nn.Sequential(
            *list(model.children())[:-1],  # All layers except the final classification layer
            nn.Flatten(1)
        )
    raise ValueError("Model name must be 'densenet', 'efficientnet', or 'resnet'.")


def load_model(prompt, mul_method, eval_set_size, cuda_device, image_preprocessor, model_name):
    """Load a fine-tuned binary classifier plus a feature extractor built from its backbone.

    The checkpoint is read from
    ``./out/{mul_method}/{prompt}/{eval_set_size}-finetuned_model-{model_name}.pth``.
    If that file does not exist, ``finetune_model_with_contrastive_learning`` is
    called first with the same arguments to produce it.

    Args:
        prompt: Prompt string; used to build the checkpoint path and forwarded to fine-tuning.
        mul_method: Method name; used to build the checkpoint path and forwarded to fine-tuning.
        eval_set_size: Evaluation-set size; used to build the checkpoint path and forwarded.
        cuda_device: Device string (e.g. ``'cuda:1'``) used as ``map_location`` and as the
            target device for both returned modules.
        image_preprocessor: Only forwarded to the fine-tuning routine.
        model_name: One of ``'densenet'``, ``'efficientnet'`` or ``'resnet'``.

    Returns:
        ``(model, feature_extractor)``. ``model`` is the classifier with a 2-way
        ``nn.Linear`` head, loaded from the checkpoint, in eval mode on ``cuda_device``.
        ``feature_extractor`` is an ``nn.Sequential`` that reuses ``model``'s backbone
        modules and ends in ``Flatten(1)``, i.e. it yields the flattened pre-classifier
        ("prefinal-layer") features.

    Raises:
        ValueError: If ``model_name`` is unsupported. This is checked before any
            fine-tuning is triggered.
    """
    # Fail fast: an invalid name must not kick off an expensive fine-tuning run first.
    if model_name not in ('densenet', 'efficientnet', 'resnet'):
        raise ValueError("Model name must be 'densenet', 'efficientnet', or 'resnet'.")

    device = torch.device(cuda_device)
    weights_file = f"./out/{mul_method}/{prompt}/{eval_set_size}-finetuned_model-{model_name}.pth"

    if not os.path.exists(weights_file):
        print(f"Weights file {weights_file} does not exist. Fine-tuning the model...")
        finetune_model_with_contrastive_learning(prompt, mul_method, eval_set_size, cuda_device, image_preprocessor, model_name)

    model = _build_binary_classifier(model_name)
    # NOTE: torch.load unpickles its input; this checkpoint is written locally by the
    # fine-tuning routine, so only point this path at trusted files.
    model.load_state_dict(torch.load(weights_file, map_location=cuda_device))
    model = model.to(device)
    model.eval()

    feature_extractor = _build_feature_extractor(model, model_name)
    feature_extractor.eval()
    feature_extractor.to(device)  # Ensure feature extractor is on the correct device

    return model, feature_extractor

In [None]:
%autoreload 2

original_pipeline = generate_diffusion_pipeline(original_weights_repo_hf, cuda_device)
finetuned_i2i_pipeline = generate_diffusion_img2img_pipeline(original_weights_repo_hf, cuda_device)
finetuned_pipeline = generate_diffusion_pipeline(original_weights_repo_hf, cuda_device)

load_frozen_weights(finetuned_i2i_pipeline, finetuned_weights_path, mul_method)
load_frozen_weights(finetuned_pipeline, finetuned_weights_path, mul_method)

In [None]:
%autoreload 2
plot_ground_truth_samples(finetuned_pipeline, original_pipeline, prompt, mul_method, seed)

In [None]:
# Compare original vs. fine-tuned img2img outputs at each ratio in `eval_infer_timestep_chkpts`.
plot_denoised_iterations(
    original_pipeline,
    finetuned_i2i_pipeline,
    prompt,
    mul_method,
    eval_infer_timestep_chkpts,
)

In [None]:
# Generate the "original" and "ablated" ground-truth sets with 5 images more than
# `eval_set_size` (NOTE(review): the reason for the +5 margin is not visible here — confirm).
# The global is bumped temporarily rather than passing `eval_set_size + 5`, presumably
# because helpers read the global directly. `finally` guarantees it is restored even if
# generation raises, so re-running this cell cannot keep inflating it.
eval_set_size += 5
try:
    generate_ground_truth_eval_set(gpu_ids, eval_set_size, prompt, "original", ablated=False)
    generate_ground_truth_eval_set(gpu_ids, eval_set_size, prompt, "ablated", ablated=True)
finally:
    eval_set_size -= 5

In [None]:
# Preview both generated sets (first argument presumably the number of images shown).
for grid_subset in ("original", "ablated"):
    print_image_grid(25, grid_subset)

In [None]:
resnet_model, resnet_feature_extractor = load_model(
    prompt, mul_method, eval_set_size, cuda_device, image_preprocessor,
    model_name='resnet',
)

# Each entry is presumably (title, x-label, y-label, metric fn, classifier, feature extractor):
# the cosine-similarity metrics get the feature extractor, the softmax metric gets the classifier.
plot_cases = (
    ("ResNet18 Finetuned (Prefinal-layer)", "Partial Diffusion Ratio", "Cosine Similarity",
     get_cosine_similarity, None, resnet_feature_extractor),
    ("ResNet18 Finetuned (Prefinal-layer)", "Partial Diffusion Ratio", "Mean Cosine Similarity",
     get_cosine_similarity_batched, None, resnet_feature_extractor),
    ("ResNet18 prediction softmax", "Partial Diffusion Ratio", "Softmax",
     get_denoised_img_pred_scores, resnet_model, None),
)
plot_grid(prompt, mul_method, cuda_device, eval_set_size, "resnet", plot_cases)

In [None]:
efficientnet_model, efficientnet_feature_extractor = load_model(
    prompt, mul_method, eval_set_size, cuda_device, image_preprocessor,
    model_name='efficientnet',
)

# Same three panels as the ResNet cell, driven by the EfficientNet classifier/extractor.
plot_cases = (
    ("EfficientNet Finetuned (Prefinal-layer)", "Partial Diffusion Ratio", "Cosine Similarity",
     get_cosine_similarity, None, efficientnet_feature_extractor),
    ("EfficientNet Finetuned (Prefinal-layer)", "Partial Diffusion Ratio", "Mean Cosine Similarity",
     get_cosine_similarity_batched, None, efficientnet_feature_extractor),
    ("EfficientNet prediction softmax", "Partial Diffusion Ratio", "Softmax",
     get_denoised_img_pred_scores, efficientnet_model, None),
)
plot_grid(prompt, mul_method, cuda_device, eval_set_size, "efficient", plot_cases)

In [None]:
densenet_model, densenet_feature_extractor = load_model(
    prompt, mul_method, eval_set_size, cuda_device, image_preprocessor,
    model_name='densenet',
)

# Same three panels as the ResNet cell, driven by the DenseNet classifier/extractor.
plot_cases = (
    ("DenseNet Finetuned (Prefinal-layer)", "Partial Diffusion Ratio", "Cosine Similarity",
     get_cosine_similarity, None, densenet_feature_extractor),
    ("DenseNet Finetuned (Prefinal-layer)", "Partial Diffusion Ratio", "Mean Cosine Similarity",
     get_cosine_similarity_batched, None, densenet_feature_extractor),
    ("DenseNet prediction softmax", "Partial Diffusion Ratio", "Softmax",
     get_denoised_img_pred_scores, densenet_model, None),
)
plot_grid(prompt, mul_method, cuda_device, eval_set_size, "dense", plot_cases)

In [None]:
# Single KID panel; both model slots are None since the metric fn receives no classifier/extractor.
plot_cases = (
    ("KID score trend", "Partial Diffusion Ratio", "KID Score",
     calculate_kid_score, None, None),
)
plot_grid(prompt, mul_method, cuda_device, eval_set_size, "other-metrics", plot_cases)