In [None]:
from utils.utils import *
import torch

In [None]:
# Number of ESD training iterations; it is embedded in every checkpoint
# filename loaded below (as `epochs_{iterations}`), so it selects which run is compared.
iterations: int = 10

In [None]:
# Must match the train_method used for training; it is also part of the saved
# checkpoint filename, so the path below is built from it to keep the two in sync.
train_method = 'xattn'
esd_path = f'/home/kyw1654/erasing/models/esd-vangogh_from_vangogh-{train_method}_1-epochs_{iterations}.pt'

device = "cuda:1"
diffuser = StableDiffuser(scheduler='DDIM').to(device)

finetuner = FineTunedModel(diffuser, train_method=train_method)
# map_location: restore tensors directly onto `device` instead of whatever GPU
# they were saved from (which may be busy or not visible on this machine).
finetuner.load_state_dict(torch.load(esd_path, map_location=device))

## Original Model

In [None]:
# Baseline sample from the unmodified model. Every comparison cell below
# reseeds a fresh generator with this same `seed` and reuses this `prompt`.
seed = 41
prompt = "Van Gogh"
sample_generator = torch.Generator().manual_seed(seed)
images = diffuser(
    prompt,
    n_imgs=1,
    n_steps=50,
    img_size=512,
    guidance_scale=7.5,
    generator=sample_generator,
)[0][0]
images

## Erased Model (full-rank fine-tuning)

In [None]:
# Same prompt and seed as the baseline, sampled inside the `finetuner` context
# (which presumably swaps in the fine-tuned ESD weights for its duration).
with finetuner:
    sample_generator = torch.Generator().manual_seed(seed)
    images = diffuser(
        prompt,
        n_imgs=1,
        n_steps=50,
        img_size=512,
        guidance_scale=7.5,
        generator=sample_generator,
    )[0][0]
images

## Erased Model (LoRA fine-tuning)

In [None]:
# LoRA variant of the ESD checkpoint. Use the notebook-wide `train_method`
# instead of a hard-coded 'xattn' so this cell cannot silently drift from the
# method the checkpoint was trained with (it is part of the filename too).
esd_path = f'/home/kyw1654/erasing/models/esd-vangogh_from_vangogh-{train_method}_1-epochs_{iterations}_lora.pt'

# NOTE(review): lora_rank / lora_alpha presumably must match the training run — confirm.
finetuner = FineTunedModel.from_checkpoint(model=diffuser,
                                           checkpoint=esd_path,
                                           train_method=train_method,
                                           lora_rank=4,
                                           lora_alpha=1.0,
                                           )

# Same prompt and seed as the baseline so the outputs are directly comparable.
with finetuner:
    images = diffuser(prompt,
                      img_size=512,
                      n_steps=50,
                      n_imgs=1,
                      generator=torch.Generator().manual_seed(seed),
                      guidance_scale=7.5
                      )[0][0]
images

## SLoU (Steered Low-rank Unlearning)

In [None]:
# SLoU checkpoint (`_lora_init` suffix). Built from the notebook-wide
# `train_method` rather than a hard-coded 'xattn' so it stays consistent with
# the method the checkpoint was trained with.
esd_path = f'/home/kyw1654/erasing/models/esd-vangogh_from_vangogh-{train_method}_1-epochs_{iterations}_lora_init.pt'

# NOTE(review): lora_rank / lora_alpha presumably must match the training run — confirm.
# lora_init_prompt is explicitly left as None when restoring from the checkpoint.
finetuner = FineTunedModel.from_checkpoint(model=diffuser,
                                           checkpoint=esd_path,
                                           train_method=train_method,
                                           lora_rank=4,
                                           lora_alpha=1.0,
                                           lora_init_prompt=None,
                                           )

# Same prompt and seed as the baseline so the outputs are directly comparable.
with finetuner:
    images = diffuser(prompt,
                      img_size=512,
                      n_steps=50,
                      n_imgs=1,
                      generator=torch.Generator().manual_seed(seed),
                      guidance_scale=7.5
                      )[0][0]
images

In [None]:
# Raw checkpoint for the weight-change analysis below. `esd_path` is whatever
# the most recently executed cell assigned; in top-to-bottom order that is the
# SLoU (`_lora_init`) checkpoint. map_location puts the tensors on `device`
# instead of the GPU they were originally saved from.
state_dict = torch.load(esd_path, map_location=device)

In [None]:
# Reference parameters of the base UNet, taken at top level (outside any
# `with finetuner:` block); keys are looked up as "<module>.weight" below.
# NOTE(review): if `diffuser.unet` is a torch nn.Module, state_dict() returns
# references to the live tensors rather than copies — confirm nothing modifies
# the base weights in place before the comparison, or clone these first.
original_state = diffuser.unet.state_dict()

In [None]:
# Disabled scratch check ('d' is a placeholder, not a real parameter name): len(original_state['d'])

In [None]:
# Relative change ||W_edited - W_orig|| / ||W_orig|| for every non-LoRA entry of
# the checkpoint. Entries whose first "_"-separated token is "lora" are skipped;
# every other entry is indexed with 'weight' and matched to the base UNet
# parameter "<key without 'unet.'>.weight".
names = []
changes = []
for module_key, module_state in state_dict.items():
    if module_key.partition("_")[0] == "lora":
        continue
    base_weight = original_state[module_key.replace('unet.', '') + ".weight"]
    tuned_weight = module_state['weight'].to(device)

    relative_change = (tuned_weight - base_weight).norm() / base_weight.norm()

    changes.append(relative_change.item())
    names.append(module_key)

In [None]:
def plot_top_k(names, values, k=3, title=None, xlabel='Names', ylabel='Values'):
    """Draw a bar chart of the ``k`` largest ``values`` labelled by ``names``.

    Args:
        names: Bar labels, paired positionally with ``values`` (as with ``zip``,
            extra items in the longer sequence are ignored).
        values: Numeric scores to rank; bars are drawn largest first.
        k: Maximum number of bars. If fewer than ``k`` pairs exist, all are drawn.
        title: Figure title; defaults to ``"Top {n} Values"`` where ``n`` is the
            number of bars actually drawn.
        xlabel: X-axis label.
        ylabel: Y-axis label.

    Returns:
        None. When there is nothing to plot (empty input or ``k <= 0``) a
        ``UserWarning`` is emitted and no figure is created.
    """
    import heapq
    import warnings

    # Equivalent to sorted(..., reverse=True)[:k] (same tie order), but O(n log k).
    top_pairs = heapq.nlargest(k, zip(names, values), key=lambda pair: pair[1])
    if not top_pairs:
        # zip(*[]) would otherwise raise "not enough values to unpack".
        warnings.warn("plot_top_k: nothing to plot (empty input or k <= 0)")
        return None
    top_names, top_values = zip(*top_pairs)

    # NOTE(review): `plt` is not imported explicitly in this notebook; it
    # presumably arrives via `from utils.utils import *`.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(top_names, top_values)
    ax.set_title(title if title is not None else f'Top {len(top_pairs)} Values')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    # Layer names are long; rotate them so dozens of bars remain legible.
    ax.tick_params(axis='x', labelrotation=90)
    fig.tight_layout()
    plt.show()

In [None]:
# The 40 modules whose weights moved the most, relative to their original norm.
plot_top_k(names=names, values=changes, k=40)

In [None]:
# Names of the 10 most-changed modules (largest relative change first): the
# exact ranking behind the bar chart above.
names_sorted = [name for name, _ in sorted(zip(names, changes), key=lambda pair: pair[1], reverse=True)]
names_sorted[:10]