In [None]:
from utils.utils import *
import torch
from PIL import Image

In [None]:
# --- Run configuration ------------------------------------------------------
# Checkpoint identifiers: these are interpolated into the checkpoint filenames
# below (presumably they must match the training run that produced them).
erased_prompt = "VanGogh".lower()  # erased concept, lower-cased before use in filenames
train_method = "xattn"             # NOTE(review): presumably "cross-attention only" — confirm in utils
iterations = 20
lora_rank = 8

# Generation settings shared by every image cell.
prompt = "Van Gogh style dog"      # "forget" prompt: contains the erased concept
remain_prompt = "A dog"            # "retain" prompt: same subject without the concept
seed = 41
device = "cuda:2"

In [None]:
# Checkpoint for the fully fine-tuned (non-LoRA) erased model.
esd_path = f'./models/esd-{erased_prompt}_from_{erased_prompt}-{train_method}_1-epochs_{iterations}.pt'

diffuser = StableDiffuser(scheduler='DDIM').to(device)

# NOTE(review): this is the "full fine-tuning" model, yet `lora_rank` is still passed —
# confirm FineTunedModel ignores it (or builds dense modules) for a non-LoRA checkpoint.
finetuner = FineTunedModel(diffuser, train_method=train_method, lora_rank=lora_rank)
# map_location: place the loaded tensors directly on `device` instead of on whatever
# device they were saved from (which may not exist on this machine).
finetuner.load_state_dict(torch.load(esd_path, map_location=device))

## Original Model

In [None]:
# Baseline: the untouched model on the prompt that contains the erased concept.
# Every image cell reseeds a fresh generator with the same `seed`, so panels are comparable.
baseline_forget_outputs = diffuser(
    prompt,
    img_size=512,
    n_steps=50,
    n_imgs=1,
    generator=torch.Generator().manual_seed(seed),
    guidance_scale=7.5,
)
forget_origin_image = baseline_forget_outputs[0][0]
forget_origin_image


In [None]:
# Baseline on the "retain" prompt (presumably meant to be unaffected by erasure).
baseline_rng = torch.Generator().manual_seed(seed)
baseline_retain_outputs = diffuser(
    remain_prompt,
    img_size=512,
    n_steps=50,
    n_imgs=1,
    generator=baseline_rng,
    guidance_scale=7.5,
)
retain_origin_image = baseline_retain_outputs[0][0]
retain_origin_image

## Erased Model (Full fine-tuning)

In [None]:
# Fully fine-tuned (erased) model on the forget prompt.
# NOTE(review): `with finetuner:` presumably applies the erased weights only inside the block — confirm in utils.
with finetuner:
    full_forget_outputs = diffuser(
        prompt,
        img_size=512,
        n_steps=50,
        n_imgs=1,
        generator=torch.Generator().manual_seed(seed),
        guidance_scale=7.5,
    )
    forget_image = full_forget_outputs[0][0]
forget_image

In [None]:
# Fully fine-tuned (erased) model on the retain prompt.
with finetuner:
    full_retain_rng = torch.Generator().manual_seed(seed)
    full_retain_outputs = diffuser(
        remain_prompt,
        img_size=512,
        n_steps=50,
        n_imgs=1,
        generator=full_retain_rng,
        guidance_scale=7.5,
    )
    retain_image = full_retain_outputs[0][0]
retain_image

## Erased Model (LoRA fine-tuning)

In [None]:
# LoRA-erased model: load the checkpoint suffixed `_lora_rank_{lora_rank}` and sample the forget prompt.
esd_path = f'./models/esd-{erased_prompt}_from_{erased_prompt}-{train_method}_1-epochs_{iterations}_lora_rank_{lora_rank}.pt'

lora_settings = dict(train_method=train_method, lora_rank=lora_rank, lora_alpha=1.0)
finetuner = FineTunedModel.from_checkpoint(model=diffuser, checkpoint=esd_path, **lora_settings)

with finetuner:
    lora_forget_outputs = diffuser(
        prompt,
        img_size=512,
        n_steps=50,
        n_imgs=1,
        generator=torch.Generator().manual_seed(seed),
        guidance_scale=7.5,
    )
    forget_lora_image = lora_forget_outputs[0][0]
forget_lora_image

In [None]:
with finetuner:
    retain_lora_image = diffuser(remain_prompt,
             img_size=512,
             n_steps=50,
             n_imgs=1,
             generator=torch.Generator().manual_seed(seed),
             guidance_scale=7.5
             )[0][0]
retain_lora_image

## SLoU (Steered Low-rank Unlearning)

In [None]:
# SLoU: same LoRA setup plus `lora_init_prompt`.
# NOTE(review): presumably this steers the adapter initialisation with the forget prompt — confirm in FineTunedModel.
esd_path = f'./models/esd-{erased_prompt}_from_{erased_prompt}-{train_method}_1-epochs_{iterations}_lora_rank_{lora_rank}_init.pt'

slou_settings = dict(
    train_method=train_method,
    lora_rank=lora_rank,
    lora_alpha=1.0,
    lora_init_prompt=prompt,
)
finetuner = FineTunedModel.from_checkpoint(model=diffuser, checkpoint=esd_path, **slou_settings)

with finetuner:
    slou_forget_outputs = diffuser(
        prompt,
        img_size=512,
        n_steps=50,
        n_imgs=1,
        generator=torch.Generator().manual_seed(seed),
        guidance_scale=7.5,
    )
    forget_init_image = slou_forget_outputs[0][0]
forget_init_image

In [None]:
with finetuner:
    retain_init_image = diffuser(remain_prompt,
             img_size=512,
             n_steps=50,
             n_imgs=1,
             generator=torch.Generator().manual_seed(seed),
             guidance_scale=7.5
             )[0][0]
retain_init_image

In [None]:
import matplotlib.pyplot as plt

# Side-by-side comparison grid: one column per model, one row per prompt.
# Each entry is (column title, image for `prompt`, image for `remain_prompt`).
comparison_columns = [
    ('Original Model', forget_origin_image, retain_origin_image),
    ('Erased Model (Full)', forget_image, retain_image),
    ('Erased Model (LoRA)', forget_lora_image, retain_lora_image),
    ('SLoU', forget_init_image, retain_init_image),
]
row_prompts = [prompt, remain_prompt]

fig, axes = plt.subplots(len(row_prompts), len(comparison_columns), figsize=(15, 6))
for col, (model_name, *row_images) in enumerate(comparison_columns):
    axes[0, col].set_title(model_name)
    for row, image in enumerate(row_images):
        ax = axes[row, col]
        ax.imshow(image)
        # Hide ticks and frame (the look of axis('off')) while keeping the y-label usable.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

# Caption each row with the prompt it was generated from, so forget vs. retain rows are distinguishable.
for row, row_prompt in enumerate(row_prompts):
    axes[row, 0].set_ylabel(row_prompt, fontsize=12)

fig.tight_layout()
plt.show()

In [None]:
# Load the SLoU checkpoint explicitly: `esd_path` is reassigned by several earlier cells, so relying on
# its last value silently changes which model is analysed when cells are re-run out of order.
# NOTE(review): point this at the full-fine-tuning checkpoint instead if that was the intended model.
slou_path = f'./models/esd-{erased_prompt}_from_{erased_prompt}-{train_method}_1-epochs_{iterations}_lora_rank_{lora_rank}_init.pt'
state_dict = torch.load(slou_path)

In [None]:
# Base unet weights to compare against the loaded checkpoint.
# NOTE(review): per the PyTorch docs, state_dict() returns detached tensors that share storage with the
# live parameters (not copies). It is also taken after the finetuners above were built — confirm neither
# left modified weights or renamed modules in `diffuser.unet`.
original_state = diffuser.unet.state_dict(keep_vars=False)

In [None]:
# len(original_state['d'])

In [None]:
# For every non-LoRA checkpoint entry, compute the relative weight change
#   ||W_checkpoint - W_base|| / ||W_base||
# using Tensor.norm()'s default (2-norm over all elements).
names, changes = [], []
for param_key, entry in state_dict.items():
    if param_key.split("_")[0] == "lora":
        continue  # presumably LoRA factor tensors, which have no matching `.weight` in the base unet
    base_weight = original_state[f"{param_key.replace('unet.','')}.weight"]
    tuned_weight = entry['weight'].to(device)
    relative_change = (tuned_weight - base_weight).norm() / base_weight.norm()
    changes.append(relative_change.item())
    names.append(param_key)

In [None]:
def plot_top_k(names, values, k=3):
    """Bar-plot the ``k`` largest ``values`` together with their ``names``.

    Args:
        names: Labels, one per value (e.g. checkpoint parameter keys).
        values: Numeric scores aligned with ``names``.
        k: How many of the largest values to show; must be >= 1. If fewer than
            ``k`` items are available, all of them are shown.

    Raises:
        ValueError: If ``k < 1``, if ``names`` and ``values`` differ in length,
            or if there is nothing to plot.
    """
    import heapq

    names, values = list(names), list(values)
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    if len(names) != len(values):
        raise ValueError(
            f"names and values must have the same length ({len(names)} != {len(values)})"
        )
    if not names:
        raise ValueError("nothing to plot: names/values are empty")

    # heapq.nlargest(k, ...) is documented as equivalent to sorted(..., reverse=True)[:k], ties included.
    top_pairs = heapq.nlargest(k, zip(names, values), key=lambda pair: pair[1])
    top_names, top_values = zip(*top_pairs)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(top_names, top_values)
    ax.set_title(f'Top {len(top_names)} Values')  # actual count, which may be < k
    ax.set_xlabel('Names')
    ax.set_ylabel('Values')
    # Parameter keys are long; rotate tick labels so neighbouring names do not overlap.
    ax.tick_params(axis='x', labelrotation=45)
    plt.setp(ax.get_xticklabels(), ha='right')
    fig.tight_layout()
    plt.show()

In [None]:
# The two checkpoint parameters whose weights moved most, relative to the base model.
plot_top_k(names=names, values=changes, k=2)