In [None]:
import os
import random
import warnings

warnings.filterwarnings("ignore")

from tqdm.auto import tqdm

import numpy as np
import pandas as pd

import torch
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
from transformers import CLIPProcessor, CLIPModel
from torchvision import transforms

from PIL import Image


# Seed Python's RNG before any sampling so that the shuffled artist order and
# the per-prompt latent seeds are identical on every "Restart & Run All".
# Without this, each run evaluates a different set of artists and the metric
# tables further down cannot be reproduced.
RANDOM_SEED = 0
random.seed(RANDOM_SEED)

df = pd.read_csv("../data/artists1734_prompts.csv")

# One entry per distinct value of the `artist` column, in shuffled order.
artists = list(df.artist.unique())
random.shuffle(artists)

# The first 20 shuffled artist names are used verbatim as text prompts, each
# paired with its own integer seed for the initial latent noise.
prompts = artists[:20]
seeds = [random.randint(0, 5000) for _ in prompts]

In [5]:
@torch.no_grad()
def generate_images(pipeline, prompts, seeds, num_inference_steps=100, guidance_scale=7.5):
    """Generate one image per prompt with a manual classifier-free-guidance loop.

    For every prompt a fresh ``LMSDiscreteScheduler`` is created, a
    ``(1, 4, 64, 64)`` latent is drawn on the CPU from ``seeds[idx]``, the
    latent is denoised with ``pipeline.unet`` and decoded with ``pipeline.vae``.

    Text embeddings and latents are cast to the dtype reported by the UNet
    (and to the VAE's dtype before decoding), so the function works both when
    those modules are in full precision and after they were converted with
    ``.half()``. With full-precision modules the casts are no-ops.

    Args:
        pipeline: Object exposing ``tokenizer``, ``text_encoder``, ``unet`` and
            ``vae``. Token ids are passed to the text encoder without being
            moved to another device.
        prompts: Iterable of prompt strings.
        seeds: Sequence indexed in parallel with ``prompts``; ``seeds[idx]``
            seeds the latent noise for the ``idx``-th prompt.
        num_inference_steps: Number of scheduler timesteps (default 100).
        guidance_scale: Weight applied to ``cond - uncond`` (default 7.5).

    Returns:
        List of ``PIL.Image`` objects, one per prompt, in prompt order.
    """
    device = pipeline.unet.device
    unet_dtype = pipeline.unet.dtype
    vae_dtype = pipeline.vae.dtype

    # Factor the latent is divided by before it is handed to the VAE decoder.
    vae_scaling = 0.18215

    images = []
    for idx, prompt in enumerate(tqdm(prompts)):

        # A new scheduler per image so no state carries over between prompts.
        scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
        scheduler.set_timesteps(num_inference_steps)

        # Row 0 is the conditional prompt, row 1 the empty (unconditional) one;
        # the order must match the `chunk(2)` unpacking in the loop below.
        token = pipeline.tokenizer([prompt, ""], padding="max_length", max_length=77, truncation=True, return_tensors="pt").input_ids
        embd = pipeline.text_encoder(token)[0].to(device=device, dtype=unet_dtype)

        # Noise is sampled with a CPU generator and only then moved/cast, so a
        # given seed produces the same starting noise regardless of device.
        generator = torch.Generator()
        generator.manual_seed(seeds[idx])
        latent = torch.randn((1, 4, 64, 64), generator=generator).to(device=device, dtype=unet_dtype)
        latent *= scheduler.init_noise_sigma

        for t in scheduler.timesteps:

            # Duplicate the latent so both prompts are evaluated in one pass.
            latent_input = torch.cat([latent] * 2)
            latent_input = scheduler.scale_model_input(latent_input, timestep=t)

            noise = pipeline.unet(latent_input, t, encoder_hidden_states=embd).sample
            cond_noise, uncond_noise = noise.chunk(2)
            noise = uncond_noise + guidance_scale * (cond_noise - uncond_noise)

            latent = scheduler.step(noise, t, latent).prev_sample

        latent = (latent / vae_scaling).to(vae_dtype)
        image = pipeline.vae.decode(latent).sample
        # Map the decoder output from [-1, 1] to [0, 1] and go channels-last.
        image = ((image + 1) / 2).clamp(0, 1).permute(0, 2, 3, 1)
        # Convert to float32 before leaving torch so the 8-bit rounding below
        # does not depend on the model precision.
        image = image.detach().float().cpu().numpy()
        image = (image * 255).round().astype("uint8")
        images.append(Image.fromarray(image[0]))
    return images

# Use the first CUDA device when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Load the checkpoint and move only the UNet and the VAE to the chosen device;
# the tokenizer and text encoder are left where `from_pretrained` put them.
pipeline = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base"
)
pipeline.unet.to(device)
pipeline.vae.to(device)

# Baseline images from the unmodified pipeline.
images = generate_images(pipeline, prompts=prompts, seeds=seeds)

def delete_pipeline(pipeline):
    """Detach the heavy sub-models from ``pipeline`` and release cached memory.

    Removes the ``vae``, ``tokenizer``, ``text_encoder`` and ``unet``
    attributes, runs the garbage collector, and empties the CUDA caching
    allocator when CUDA is available.

    The function is safe to call more than once on the same object (for
    example when a cell is re-run): attributes that are already gone are
    skipped instead of raising ``AttributeError``.

    Note that the caller's own ``pipeline`` variable still refers to the (now
    stripped) pipeline object afterwards; rebinding or deleting that name is
    up to the caller.

    Args:
        pipeline: Object whose sub-model attributes should be removed.

    Returns:
        None.
    """
    # Local import so this cell does not depend on the notebook's import cell.
    import gc

    for component in ("vae", "tokenizer", "text_encoder", "unet"):
        try:
            delattr(pipeline, component)
        except AttributeError:
            # Already removed by an earlier call; nothing left to free here.
            pass

    # Collect first so tensors kept alive only by reference cycles are
    # actually freed before the allocator cache is emptied.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Drop this pipeline's sub-models and clear the CUDA cache; later cells load
# the checkpoint again before generating.
delete_pipeline(pipeline)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

In [None]:
@torch.no_grad()
def generate_images2(pipeline, prompts, seeds):
    """Half-precision counterpart of the manual guidance sampling loop.

    The text embeddings and the initial latent are converted with ``.half()``
    before they reach ``pipeline.unet``; everything else (100 LMS steps,
    guidance weight 7.5, ``(1, 4, 64, 64)`` latent seeded from ``seeds``)
    mirrors the full-precision sampler.

    Args:
        pipeline: Object exposing ``tokenizer``, ``text_encoder``, ``unet``
            and ``vae``.
        prompts: Iterable of prompt strings.
        seeds: Sequence indexed in parallel with ``prompts``.

    Returns:
        List of ``PIL.Image`` objects, one per prompt.
    """
    target = pipeline.unet.device
    outputs = []

    for position, text in enumerate(tqdm(prompts)):
        sampler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
        )
        sampler.set_timesteps(100)

        # Conditional prompt first, empty prompt second.
        ids = pipeline.tokenizer(
            [text, ""],
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt",
        ).input_ids
        context = pipeline.text_encoder(ids)[0].to(target).half()

        # CPU generator, then move and cast the sampled noise.
        rng = torch.Generator()
        rng.manual_seed(seeds[position])
        sample = torch.randn((1, 4, 64, 64), generator=rng).to(target).half()
        sample *= sampler.init_noise_sigma

        for step in sampler.timesteps:
            doubled = torch.cat([sample, sample])
            doubled = sampler.scale_model_input(doubled, timestep=step)

            predicted = pipeline.unet(doubled, step, encoder_hidden_states=context).sample
            guided, unguided = predicted.chunk(2)
            predicted = unguided + 7.5 * (guided - unguided)

            sample = sampler.step(predicted, step, sample).prev_sample

        sample /= 0.18215
        decoded = pipeline.vae.decode(sample).sample
        # [-1, 1] -> [0, 1], channels-last, then 8-bit pixels.
        decoded = ((decoded + 1) / 2).clamp(0, 1).permute(0, 2, 3, 1)
        pixels = decoded.detach().cpu().numpy()
        pixels = (pixels * 255).round().astype("uint8")
        outputs.append(Image.fromarray(pixels[0]))

    return outputs

# Use the first CUDA device when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Same checkpoint as before, but the UNet and VAE are converted to half
# precision after being moved to the device.
pipeline = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base"
)
pipeline.unet.to(device).half()
pipeline.vae.to(device).half()

# Half-precision baseline, generated with the same prompts and seeds.
images2 = generate_images2(pipeline, prompts=prompts, seeds=seeds)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

In [80]:
# Concepts whose behaviour should be preserved: shuffled artists 10..99.
retain_prompts = artists[10:100]

# Concepts to erase: the first ten generation prompts.
prev_prompts = prompts[:10]

# Every erased concept is redirected to the same generic prompt. The list is
# sized from `prev_prompts` (instead of a literal 10) because erase_pipeline
# multiplies the two embedding batches together, so their lengths must agree
# even when fewer than ten prompts are available.
new_prompts = ["art"] * len(prev_prompts)

@torch.no_grad()
def erase_pipeline(pipeline, prev_prompts, new_prompts, retain_prompts):
    """Rewrite the cross-attention key/value projections with a closed-form edit.

    For every UNet module whose name ends in ``attn2`` the ``to_v`` and
    ``to_k`` weights ``W`` are replaced by ``W @ m3 @ inverse(m2)`` where

    * ``m2 = sum(prev @ prev.T) + lamb * I + preserve_scale * sum(retain @ retain.T)``
    * ``m3 = sum(new  @ prev.T) + lamb * I + preserve_scale * sum(retain @ retain.T)``

    and ``prev`` / ``new`` / ``retain`` are the text-encoder embeddings of the
    corresponding prompt lists.

    The arithmetic is carried out in the dtype of the embeddings and the
    result is cast back to each layer's original dtype, so half-precision
    weights are supported; full-precision weights are processed unchanged.

    Args:
        pipeline: Object exposing ``tokenizer``, ``text_encoder`` and ``unet``.
        prev_prompts: Prompts whose concept should be replaced.
        new_prompts: Replacement prompts; must have the same length as
            ``prev_prompts`` because the two embedding batches are multiplied
            together.
        retain_prompts: Prompts whose behaviour should be preserved. May be
            empty or ``None``, in which case no preservation term is added.

    Returns:
        The same ``pipeline`` object, modified in place.
    """
    device = pipeline.unet.device

    def _embed(texts):
        """Encode ``texts`` and swap the last two axes of the hidden states."""
        ids = pipeline.tokenizer(texts, padding="max_length", max_length=77, truncation=True, return_tensors="pt").input_ids
        return pipeline.text_encoder(ids)[0].permute(0, 2, 1).to(device)

    # Cross-attention blocks; value projections first, then key projections.
    ca_layers = [module for name, module in pipeline.unet.named_modules() if name.endswith("attn2")]
    target_layers = [layer.to_v for layer in ca_layers] + [layer.to_k for layer in ca_layers]

    prev_embds = _embed(prev_prompts)
    new_embds = _embed(new_prompts)
    prev_embds_t = prev_embds.permute(0, 2, 1)

    lamb = 0.5
    erase_scale = 1
    preserve_scale = 0.1

    m2 = (prev_embds @ prev_embds_t).sum(0) * erase_scale
    m2 += lamb * torch.eye(m2.shape[0], device=device)

    m3 = (new_embds @ prev_embds_t).sum(0) * erase_scale
    m3 += lamb * torch.eye(m3.shape[0], device=device)

    if retain_prompts is not None and len(retain_prompts) > 0:
        retain_embds = _embed(retain_prompts)
        # The preservation term is identical for m2 and m3: compute it once.
        retain_gram = (retain_embds @ retain_embds.permute(0, 2, 1)).sum(0) * preserve_scale
        m2 += retain_gram
        m3 += retain_gram

    # m2 does not depend on the layer, so invert it once instead of per layer.
    m2_inv = torch.inverse(m2)

    for target_layer in target_layers:
        weight = target_layer.weight
        updated = (weight.to(m3.dtype) @ m3) @ m2_inv
        target_layer.weight = torch.nn.Parameter(updated.to(weight.dtype).detach())

    return pipeline

# Release whatever pipeline the previous cell left behind before loading anew.
delete_pipeline(pipeline)

if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

pipeline = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base"
)
pipeline.unet.to(device)
pipeline.vae.to(device)

# Edit the cross-attention weights on the device, then sample with the same
# prompts and seeds as the unedited runs.
pipeline = erase_pipeline(
    pipeline,
    prev_prompts=prev_prompts,
    new_prompts=new_prompts,
    retain_prompts=retain_prompts,
)

images3 = generate_images(pipeline, prompts=prompts, seeds=seeds)

# Done with this pipeline; free its models again.
delete_pipeline(pipeline)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

In [75]:
@torch.no_grad()
def erase_pipeline2(pipeline, prev_prompts, new_prompts, retain_prompts):
    """Closed-form cross-attention edit for a half-precision UNet.

    For every UNet module whose name ends in ``attn2`` the ``to_v`` and
    ``to_k`` weights ``W`` are replaced by ``W @ m3 @ inverse(m2)`` where

    * ``m2 = sum(prev @ prev.T) + lamb * I + preserve_scale * sum(retain @ retain.T)``
    * ``m3 = sum(new  @ prev.T) + lamb * I + preserve_scale * sum(retain @ retain.T)``

    Each weight is upcast to ``float32`` for the matrix products and the
    result is always stored as ``float16``, whatever the input dtype was.

    Args:
        pipeline: Object exposing ``tokenizer``, ``text_encoder`` and ``unet``.
        prev_prompts: Prompts whose concept should be replaced.
        new_prompts: Replacement prompts; must have the same length as
            ``prev_prompts`` because the two embedding batches are multiplied
            together.
        retain_prompts: Prompts whose behaviour should be preserved. May be
            empty or ``None``, in which case no preservation term is added.

    Returns:
        The same ``pipeline`` object, modified in place.
    """
    device = pipeline.unet.device

    def _embed(texts):
        """Encode ``texts`` and swap the last two axes of the hidden states."""
        ids = pipeline.tokenizer(texts, padding="max_length", max_length=77, truncation=True, return_tensors="pt").input_ids
        return pipeline.text_encoder(ids)[0].permute(0, 2, 1).to(device)

    # Cross-attention blocks; value projections first, then key projections.
    ca_layers = [module for name, module in pipeline.unet.named_modules() if name.endswith("attn2")]
    target_layers = [layer.to_v for layer in ca_layers] + [layer.to_k for layer in ca_layers]

    prev_embds = _embed(prev_prompts)
    new_embds = _embed(new_prompts)
    prev_embds_t = prev_embds.permute(0, 2, 1)

    lamb = 0.5
    erase_scale = 1
    preserve_scale = 0.1

    m2 = (prev_embds @ prev_embds_t).sum(0) * erase_scale
    m2 += lamb * torch.eye(m2.shape[0], device=device)

    m3 = (new_embds @ prev_embds_t).sum(0) * erase_scale
    m3 += lamb * torch.eye(m3.shape[0], device=device)

    if retain_prompts is not None and len(retain_prompts) > 0:
        retain_embds = _embed(retain_prompts)
        # The preservation term is identical for m2 and m3: compute it once.
        retain_gram = (retain_embds @ retain_embds.permute(0, 2, 1)).sum(0) * preserve_scale
        m2 += retain_gram
        m3 += retain_gram

    # m2 does not depend on the layer, so invert it once instead of per layer.
    m2_inv = torch.inverse(m2)

    for target_layer in target_layers:
        m1 = target_layer.weight.to(torch.float32) @ m3
        target_layer.weight = torch.nn.Parameter((m1 @ m2_inv).to(torch.float16).detach())

    return pipeline

if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Here the UNet and VAE are cast to half precision *before* the weights are
# edited, so the edit operates on half-precision parameters.
pipeline = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base"
)
pipeline.unet.to(device).half()
pipeline.vae.to(device).half()

pipeline = erase_pipeline2(
    pipeline,
    prev_prompts=prev_prompts,
    new_prompts=new_prompts,
    retain_prompts=retain_prompts,
)

images4 = generate_images2(pipeline, prompts=prompts, seeds=seeds)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

In [74]:
# In this variant the weights are edited first, on the freshly loaded
# pipeline, and only afterwards moved to the device and cast to half precision.
pipeline = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base"
)
pipeline = erase_pipeline(
    pipeline,
    prev_prompts=prev_prompts,
    new_prompts=new_prompts,
    retain_prompts=retain_prompts,
)

if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

pipeline.unet.to(device).half()
pipeline.vae.to(device).half()

images5 = generate_images2(pipeline, prompts=prompts, seeds=seeds)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

In [None]:
def preprocess_images(images, score=None):
    """Resize images to 224x224, convert them to tensors and stack them.

    Args:
        images: Iterable of images accepted by the torchvision transforms
            (PIL images in this notebook).
        score: Name of the metric the batch is prepared for. ``"lpips"``
            rescales the stacked tensor with ``x * 2 - 1``; ``"clip"`` and any
            other value return the stacked tensor unchanged.

    Returns:
        A single tensor holding all converted images, stacked along a new
        leading axis.
    """
    to_tensor = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    batch = torch.stack([to_tensor(picture) for picture in images])

    if score == "lpips":
        return batch * 2 - 1
    return batch

@torch.no_grad()
def measure_LPIPS(images1, images2):
    """Return the LPIPS (AlexNet) distance between paired images.

    Both image lists are preprocessed with ``preprocess_images(..., "lpips")``
    and compared index by index.

    Args:
        images1: First list of images.
        images2: Second list of images, paired with ``images1`` by position.

    Returns:
        NumPy array of distances rounded to three decimals (squeezed, so a
        single pair yields a 0-d array).
    """
    # `lpips` is used here but is not part of the notebook's import cell, so a
    # fresh kernel would fail with NameError. Import it where it is needed.
    import lpips

    images1 = preprocess_images(images1, "lpips")
    images2 = preprocess_images(images2, "lpips")

    loss_function = lpips.LPIPS(net='alex')
    return loss_function(images1, images2).squeeze().detach().numpy().round(3)

@torch.no_grad()
def measure_CLIP(images, prompts):
    """Score how well each image matches the prompt at the same position.

    Images and prompts are embedded with the ``openai/clip-vit-base-patch32``
    checkpoint, and the cosine similarity between each image embedding and the
    corresponding text embedding is returned.

    Args:
        images: List of images; preprocessed with
            ``preprocess_images(..., "clip")``.
        prompts: List of prompt strings, presumably the same length as
            ``images`` since the embeddings are compared pairwise.

    Returns:
        NumPy array of cosine similarities rounded to three decimals.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    clip = CLIPModel.from_pretrained(checkpoint)
    input_builder = CLIPProcessor.from_pretrained(checkpoint)

    pixel_batch = preprocess_images(images, "clip")
    # do_rescale=False: the processor is told not to rescale the pixel values
    # of the already-converted tensor batch again.
    batch = input_builder(
        text=prompts,
        images=pixel_batch,
        return_tensors="pt",
        padding=True,
        do_rescale=False,
    )

    result = clip(**batch)
    similarity = torch.nn.functional.cosine_similarity(
        result.image_embeds, result.text_embeds
    )
    return similarity.numpy().round(3)

In [81]:
# Each measure_CLIP call reloads the CLIP checkpoint and re-embeds the whole
# image set, so score every image set exactly once and reuse the result
# (images3 appears in two of the comparisons below).
clip_scores = {
    "1": measure_CLIP(images, prompts),
    "2": measure_CLIP(images2, prompts),
    "3": measure_CLIP(images3, prompts),
    "4": measure_CLIP(images4, prompts),
    "5": measure_CLIP(images5, prompts),
}

# Columns: perceptual distance between two image sets, and the difference of
# their per-prompt CLIP scores (first set minus second set).
pd.DataFrame({"LPIPS 1-2": measure_LPIPS(images, images2),
              "CLIP 1-2": clip_scores["1"] - clip_scores["2"],
              "LPIPS 3-4": measure_LPIPS(images3, images4),
              "CLIP 3-4": clip_scores["3"] - clip_scores["4"],
              "LPIPS 3-5": measure_LPIPS(images3, images5),
              "CLIP 3-5": clip_scores["3"] - clip_scores["5"],
              })

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: c:\Users\yoonj\AppData\Local\Programs\Python\Python312\Lib\site-packages\lpips\weights\v0.1\alex.pth
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: c:\Users\yoonj\AppData\Local\Programs\Python\Python312\Lib\site-packages\lpips\weights\v0.1\alex.pth
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: c:\Users\yoonj\AppData\Local\Programs\Python\Python312\Lib\site-packages\lpips\weights\v0.1\alex.pth


Unnamed: 0,LPIPS 1-2,CLIP 1-2,LPIPS 3-4,CLIP 3-4,LPIPS 3-5,CLIP 3-5
0,0.0,0.0,0.022,0.005,0.031,0.002
1,0.0,-0.002,0.008,0.001,0.152,-0.007
2,0.026,-0.011,0.01,-0.002,0.087,-0.01
3,0.008,-0.003,0.066,-0.006,0.14,0.028
4,0.0,-0.001,0.027,-0.003,0.097,-0.007
5,0.0,0.0,0.066,-0.002,0.031,0.009
6,0.0,0.001,0.02,-0.003,0.035,-0.008
7,0.003,0.002,0.024,0.006,0.013,-0.004
8,0.003,-0.001,0.009,0.003,0.015,0.0
9,0.005,-0.003,0.196,0.02,0.179,0.003
