In [1]:
import os
import random
import warnings

warnings.filterwarnings("ignore")

import lpips

from tqdm.auto import tqdm

import numpy as np
import pandas as pd

import torch
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
from transformers import CLIPProcessor, CLIPModel
from torchvision import transforms

from PIL import Image


df = pd.read_csv("../data/artists1734_prompts.csv")

artists = list(df.artist.unique())
random.shuffle(artists)

prompts = artists[:10]
seeds = [random.randint(0, 5000) for _ in prompts]

In [2]:
def delete_pipeline(pipeline):
    del pipeline.vae
    del pipeline.tokenizer
    del pipeline.text_encoder
    del pipeline.unet
    del pipeline
    torch.cuda.empty_cache()

@torch.no_grad()
def generate_autohalf_images(pipeline, prompts, seeds):
    device = pipeline.device
    images = []
    
    with torch.autocast(device_type="cuda"):
        for idx, prompt in enumerate(tqdm(prompts)):

            scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
            scheduler.set_timesteps(100)

            prompt = [prompt] + [""]
            
            token = pipeline.tokenizer(
                prompt,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt"
            ).input_ids.to(device)
            
            embd = pipeline.text_encoder(token)[0]
            
            seed = seeds[idx]
            generator = torch.Generator(device=device).manual_seed(seed)
            
            latent = torch.randn((1, 4, 64, 64), generator=generator, device=device, dtype=torch.float16)
            latent *= scheduler.init_noise_sigma
            
            for t in scheduler.timesteps:
                latent_input = torch.cat([latent] * 2)
                latent_input = scheduler.scale_model_input(latent_input, timestep=t)
                
                noise = pipeline.unet(latent_input, t, encoder_hidden_states=embd).sample
                cond_noise, uncond_noise = noise.chunk(2)
                noise = uncond_noise + 7.5 * (cond_noise - uncond_noise)
                
                latent = scheduler.step(noise, t, latent).prev_sample
            
            image = pipeline.vae.decode(latent).sample
            image = ((image + 1) / 2).clamp(0, 1).permute(0, 2, 3, 1)
            image = image.detach().cpu().numpy()
            image = (image * 255).round().astype("uint8")
            image = Image.fromarray(image[0])
            images.append(image)
    
    return images

@torch.no_grad()
def generate_half_images(pipeline, prompts, seeds):
    device = pipeline.device
    images = []
    
    for idx, prompt in enumerate(tqdm(prompts)):

        scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
        scheduler.set_timesteps(100)

        prompt = [prompt] + [""]
        
        token = pipeline.tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt"
        ).input_ids.to(device)
        
        embd = pipeline.text_encoder(token)[0]
        
        seed = seeds[idx]
        generator = torch.Generator(device=device).manual_seed(seed)
        
        latent = torch.randn((1, 4, 64, 64), generator=generator, device=device, dtype=torch.float16)
        latent *= scheduler.init_noise_sigma
        
        for t in scheduler.timesteps:
            latent_input = torch.cat([latent] * 2)
            latent_input = scheduler.scale_model_input(latent_input, timestep=t)
            
            noise = pipeline.unet(latent_input, t, encoder_hidden_states=embd).sample
            cond_noise, uncond_noise = noise.chunk(2)
            noise = uncond_noise + 7.5 * (cond_noise - uncond_noise)
            
            latent = scheduler.step(noise, t, latent).prev_sample
        
        image = pipeline.vae.decode(latent).sample
        image = ((image + 1) / 2).clamp(0, 1).permute(0, 2, 3, 1)
        image = image.detach().cpu().numpy()
        image = (image * 255).round().astype("uint8")
        image = Image.fromarray(image[0])
        images.append(image)
    
    return images

@torch.no_grad()
def generate_images(pipeline, prompts, seeds):
    device = pipeline.device
    images = []
    
    for idx, prompt in enumerate(tqdm(prompts)):

        scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
        scheduler.set_timesteps(100)

        prompt = [prompt] + [""]
        
        token = pipeline.tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt"
        ).input_ids.to(device)
        embd = pipeline.text_encoder(token)[0]
        
        seed = seeds[idx]
        generator = torch.Generator(device=device).manual_seed(seed)
        
        latent = torch.randn((1, 4, 64, 64), generator=generator, device=device)
        latent *= scheduler.init_noise_sigma
        
        for t in scheduler.timesteps:
            latent_input = torch.cat([latent] * 2)
            latent_input = scheduler.scale_model_input(latent_input, timestep=t)
            
            noise = pipeline.unet(latent_input, t, encoder_hidden_states=embd).sample
            cond_noise, uncond_noise = noise.chunk(2)
            noise = uncond_noise + 7.5 * (cond_noise - uncond_noise)
            
            latent = scheduler.step(noise, t, latent).prev_sample
        
        image = pipeline.vae.decode(latent).sample
        image = ((image + 1) / 2).clamp(0, 1).permute(0, 2, 3, 1)
        image = image.detach().cpu().numpy()
        image = (image * 255).round().astype("uint8")
        image = Image.fromarray(image[0])
        images.append(image)
    
    return images

def preprocess_images(images):
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    images_ = []
    for image in images:
        image = preprocess(image)
        images_.append(image)
    images_ = torch.stack(images_)

    return images_

@torch.no_grad()
def compare_images(images1, images2, prompts):

    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    images1 = preprocess_images(images1)
    images2 = preprocess_images(images2)

    inputs = processor(text=prompts, images=images1, return_tensors="pt", padding=True, do_rescale=False)
    outputs = model(**inputs)
    image_embds1 = outputs.image_embeds

    inputs = processor(text=prompts, images=images2, return_tensors="pt", padding=True, do_rescale=False)
    outputs = model(**inputs)
    image_embds2 = outputs.image_embeds

    text_embds = outputs.text_embeds

    clip_score1 = torch.nn.functional.cosine_similarity(image_embds1, text_embds).numpy().round(3)
    clip_score2 = torch.nn.functional.cosine_similarity(image_embds2, text_embds).numpy().round(3)

    loss_function = lpips.LPIPS(net='alex')
    images1 = images1 * 2 - 1
    images2 = images2 * 2 - 1

    return pd.DataFrame({"CLIP 1": clip_score1, "CLIP 2": clip_score2,
                         "CLIP diff": clip_score1 - clip_score2,
                         "LPIPS diff": loss_function(images1, images2).squeeze().detach().numpy().round(3)})


In [None]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16).to(device)

autohalf_images = generate_autohalf_images(pipeline, prompts, seeds)

delete_pipeline(pipeline)

pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16).to(device)

half_images = generate_half_images(pipeline, prompts, seeds)

delete_pipeline(pipeline)

pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base").to(device)

images = generate_images(pipeline, prompts, seeds)

delete_pipeline(pipeline)

df1 = compare_images(autohalf_images, half_images, prompts)
df2 = compare_images(autohalf_images, images, prompts)
df3 = compare_images(half_images, images, prompts)

In [None]:
retain_prompts = artists[5:]
prev_prompts = prompts[:5]
new_prompts = ["art"] * 5

@torch.no_grad()
def erase_pipeline(pipeline, prev_prompts, new_prompts, retain_prompts):

    device = pipeline.device

    ca_layers = []
    for name, module in pipeline.unet.named_modules():
        if name[-5:] != "attn2": continue
        ca_layers.append(module)

    value_layers = [layer.to_v for layer in ca_layers]
    target_layers = value_layers

    key_layers = [layer.to_k for layer in ca_layers]
    target_layers += key_layers
    
    prev_tokens = pipeline.tokenizer(prev_prompts, padding="max_length", max_length=77, truncation=True, return_tensors="pt").input_ids.to(device)
    prev_embds = pipeline.text_encoder(prev_tokens)[0].permute(0, 2, 1)
    
    new_tokens = pipeline.tokenizer(new_prompts, padding="max_length", max_length=77, truncation=True, return_tensors="pt").input_ids.to(device)
    new_embds = pipeline.text_encoder(new_tokens)[0].permute(0, 2, 1)

    lamb = 0.5
    erase_scale = 1
    preserve_scale = 0.1

    m2 = (prev_embds @ prev_embds.permute(0, 2, 1)).sum(0) * erase_scale
    m2 += lamb * torch.eye(m2.shape[0], device=device)

    m3 = (new_embds @ prev_embds.permute(0, 2, 1)).sum(0) * erase_scale
    m3 += lamb * torch.eye(m3.shape[0], device=device)

    retain_tokens = pipeline.tokenizer(retain_prompts, padding="max_length", max_length=77, truncation=True, return_tensors="pt").input_ids.to(device)
    retain_embds = pipeline.text_encoder(retain_tokens)[0].permute(0, 2, 1)

    m2 += (retain_embds @ retain_embds.permute(0, 2, 1)).sum(0) * preserve_scale
    m3 += (retain_embds @ retain_embds.permute(0, 2, 1)).sum(0) * preserve_scale

    for target_layer in target_layers:
        m1 = target_layer.weight @ m3
        target_layer.weight = torch.nn.Parameter((m1 @ torch.inverse(m2)).detach())

    return pipeline

device = "cuda:0" if torch.cuda.is_available() else "cpu"
pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base").to(device)

erased_pipeline = erase_pipeline(pipeline, prev_prompts, new_prompts, retain_prompts)

erased_images = generate_images(erased_pipeline, prompts, seeds)

erased_pipeline.text_encoder.half()
erased_pipeline.vae.half()
erased_pipeline.unet.half()

erased_half_images = generate_half_images(erased_pipeline, prompts, seeds)
erased_autohalf_images = generate_autohalf_images(erased_pipeline, prompts, seeds)

df4 = compare_images(erased_autohalf_images, erased_half_images, prompts)
df5 = compare_images(erased_autohalf_images, erased_images, prompts)
df6 = compare_images(erased_half_images, erased_images, prompts)

In [15]:
df4

Unnamed: 0,CLIP 1,CLIP 2,CLIP diff,LPIPS diff
0,0.226,0.228,-0.002,0.002
1,0.204,0.203,0.001,0.001
2,0.211,0.211,0.0,0.001
3,0.199,0.196,0.003,0.001
4,0.147,0.148,-0.001,0.0
5,0.274,0.274,0.0,0.0
6,0.21,0.212,-0.002,0.002
7,0.244,0.246,-0.002,0.006
8,0.296,0.297,-0.001,0.0
9,0.169,0.168,0.001,0.002


In [16]:
df5

Unnamed: 0,CLIP 1,CLIP 2,CLIP diff,LPIPS diff
0,0.226,0.227,-0.001,0.002
1,0.204,0.202,0.002,0.002
2,0.211,0.211,0.0,0.003
3,0.199,0.196,0.003,0.004
4,0.147,0.147,0.0,0.001
5,0.274,0.273,0.001,0.001
6,0.21,0.214,-0.004,0.005
7,0.244,0.242,0.002,0.002
8,0.296,0.297,-0.001,0.001
9,0.169,0.168,0.001,0.007


In [17]:
df6

Unnamed: 0,CLIP 1,CLIP 2,CLIP diff,LPIPS diff
0,0.228,0.227,0.001,0.001
1,0.203,0.202,0.001,0.004
2,0.211,0.211,0.0,0.003
3,0.196,0.196,0.0,0.004
4,0.148,0.147,0.001,0.001
5,0.274,0.273,0.001,0.001
6,0.212,0.214,-0.002,0.007
7,0.246,0.242,0.004,0.008
8,0.297,0.297,0.0,0.001
9,0.168,0.168,0.0,0.008
