In [None]:
from utils.utils import *
import torch

In [None]:
# Experiment configuration read by the cells below.
prompt = "Van Gogh"      # generation prompt for every model (the concept named in the checkpoint files)
train_method = "full"    # passed to FineTunedModel and embedded in checkpoint filenames
iterations = 10          # embedded in checkpoint filenames
lora_rank = 32           # LoRA rank; passed to FineTunedModel and embedded in LoRA checkpoint filenames
seed = 41                # seeds each torch.Generator so every model starts from the same noise
device = "cuda:1"        # target device for the pipeline

In [None]:
# Checkpoint of the model erased with full fine-tuning (no LoRA).
esd_path = f'/home/kyw1654/erasing/models/esd-vangogh_from_vangogh-{train_method}_1-epochs_{iterations}.pt'

# Base pipeline with a DDIM scheduler, moved to the configured device.
diffuser = StableDiffuser(scheduler='DDIM').to(device)

# Fine-tuning wrapper around `diffuser` for the modules selected by `train_method`;
# its weights are restored from the checkpoint.
finetuner = FineTunedModel(diffuser, train_method=train_method, lora_rank=lora_rank)
# map_location places the loaded tensors directly on `device`. Without it torch.load
# restores them to whatever device they were saved from (e.g. cuda:0), which may be
# a different GPU from `device` or absent on this machine.
finetuner.load_state_dict(torch.load(esd_path, map_location=device))

## Original Model

In [None]:
# Baseline: sample from the unmodified pipeline (no finetuner context active).
# A freshly seeded generator gives the same starting noise as the erased-model cells.
origin_images = diffuser(
    prompt,
    guidance_scale=7.5,
    n_steps=50,
    img_size=512,
    n_imgs=1,
    generator=torch.Generator().manual_seed(seed),
)[0][0]  # NOTE(review): [0][0] presumably selects the single generated image — confirm return layout
origin_images

## Erased Model (Full fine-tuning)

In [None]:
# Sample inside the finetuner context. At this point `finetuner` is still the
# full-fine-tuning model loaded above; later cells rebind the name, so running
# this cell out of order would use a different model.
with finetuner:
    images = diffuser(
        prompt,
        generator=torch.Generator().manual_seed(seed),
        img_size=512,
        n_imgs=1,
        n_steps=50,
        guidance_scale=7.5,
    )[0][0]
images

## Erased Model (LoRA fine-tuning)

In [None]:
# Checkpoint of the model erased with LoRA fine-tuning at rank `lora_rank`.
esd_path = f'/home/kyw1654/erasing/models/esd-vangogh_from_vangogh-{train_method}_1-epochs_{iterations}_lora_rank_{lora_rank}.pt'

# Rebinds `finetuner`: from here on it refers to the LoRA model, not the
# full-fine-tuning model loaded earlier.
finetuner = FineTunedModel.from_checkpoint(model=diffuser,
                                           checkpoint=esd_path,
                                           train_method=train_method,
                                           lora_rank=lora_rank,
                                           lora_alpha=1.0,
                                           )

# Same prompt, seed and sampler settings as the other models so outputs are comparable.
with finetuner:
    lora_images = diffuser(prompt,
             img_size=512,
             n_steps=50,
             n_imgs=1,
             generator=torch.Generator().manual_seed(seed),
             guidance_scale=7.5
             )[0][0]
lora_images

## SLoU (Steered Low-rank Unlearning)

In [None]:
# SLoU checkpoint; the `_init` suffix presumably marks the steered-initialization
# LoRA variant — confirm against the training script.
esd_path = f'/home/kyw1654/erasing/models/esd-vangogh_from_vangogh-{train_method}_1-epochs_{iterations}_lora_rank_{lora_rank}_init.pt'

# Rebinds `finetuner` to the SLoU model.
finetuner = FineTunedModel.from_checkpoint(
    model=diffuser,
    checkpoint=esd_path,
    train_method=train_method,
    lora_rank=lora_rank,
    lora_alpha=1.0,
    # NOTE(review): presumably no init prompt is needed because the weights come from the checkpoint.
    lora_init_prompt=None,
)

# Same prompt, seed and sampler settings as the other models.
with finetuner:
    init_images = diffuser(
        prompt,
        n_imgs=1,
        img_size=512,
        guidance_scale=7.5,
        n_steps=50,
        generator=torch.Generator().manual_seed(seed),
    )[0][0]
init_images

In [None]:
import matplotlib.pyplot as plt

# Side-by-side comparison: every image below was generated from the same prompt and seed.
comparison = [
    ('Original Model', origin_images),
    ('Erased Model (Full)', images),
    ('Erased Model (LoRA)', lora_images),
    ('SLoU', init_images),
]

# One row with one panel per model; a loop replaces four copy-pasted subplot stanzas.
fig, axes = plt.subplots(1, len(comparison), figsize=(15, 6))
for ax, (title, image) in zip(axes, comparison):
    ax.imshow(image)
    ax.set_title(title)
    ax.axis('off')

fig.tight_layout()
plt.show()

In [None]:
# Load the checkpoint named by the *current* `esd_path` for weight-change analysis.
# NOTE: this depends on the last cell that assigned `esd_path` (the SLoU cell in a
# top-to-bottom run); re-running an earlier model cell changes which checkpoint is analysed.
# map_location puts the tensors on `device` (where the pipeline lives) instead of
# whatever device they were saved from.
state_dict = torch.load(esd_path, map_location=device)

In [None]:
# Base U-Net parameters, keyed "<module path>.weight" / ".bias", used as the reference
# for the relative weight-change analysis.
# NOTE(review): state_dict() returns detached tensors that share storage with the live
# parameters (no copy). If entering `with finetuner:` ever mutates unet weights in place,
# these values would change too — confirm FineTunedModel swaps modules instead.
original_state = diffuser.unet.state_dict()

In [None]:
# len(original_state['d'])

In [None]:
# For each non-LoRA entry in the loaded checkpoint, record the relative change of its
# weight versus the base model: ||W_edited - W_orig|| / ||W_orig|| (norm over all elements).
names = []
changes = []
for key, value in state_dict.items():
    # Skip entries whose first "_"-separated token is "lora".
    if key.split("_")[0] == "lora":
        continue
    # Checkpoint keys carry a "unet." prefix that the U-Net's own state_dict lacks.
    original_value = original_state[key.replace('unet.', '') + ".weight"]
    edited_value = value['weight'].to(device)
    change = (edited_value - original_value).norm()
    changes.append((change / original_value.norm()).item())
    names.append(key)

In [None]:
def plot_top_k(names, values, k=3, title=None, xlabel='Names', ylabel='Values'):
    """Bar-plot the ``k`` largest ``values``, labelled by their ``names``.

    Args:
        names: Labels, paired positionally with ``values`` (as with ``zip``,
            surplus entries of the longer sequence are ignored).
        values: Numbers to rank; they must be mutually comparable.
        k: Number of largest values to show. If fewer pairs exist, all of
            them are shown.
        title: Figure title. Defaults to ``"Top {k} Values"``.
        xlabel: X-axis label.
        ylabel: Y-axis label.

    Raises:
        ValueError: If ``k < 1`` or there is nothing to plot.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    # Stable sort: ties keep their original relative order.
    top = sorted(zip(names, values), key=lambda pair: pair[1], reverse=True)[:k]
    if not top:
        raise ValueError("nothing to plot: names/values are empty")
    top_names, top_values = zip(*top)

    fig, ax = plt.subplots(figsize=(10, 6))
    # Bars sit at integer positions rather than using names as categorical x values.
    # matplotlib would merge bars whose labels are equal, and long module names need
    # rotated tick labels to stay readable.
    positions = list(range(len(top)))
    ax.bar(positions, top_values)
    ax.set_xticks(positions)
    ax.set_xticklabels([str(name) for name in top_names], rotation=45, ha='right')
    ax.set_title(title if title is not None else f'Top {k} Values')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    fig.tight_layout()
    plt.show()

In [None]:
# Show the two checkpoint entries whose weights changed most relative to the base model.
plot_top_k(names=names, values=changes, k=2)

In [None]:
# names_sorted[:10]