In [1]:
import random
import warnings

warnings.filterwarnings("ignore")

import yaml

import pandas as pd

import torch
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

#
name = "erase4"
version = "stabilityai/stable-diffusion-2-1-base"

with open(f"600.yaml", 'r', encoding='utf-8') as file:
    data600 = yaml.safe_load(file)

with open(f"300.yaml", 'r', encoding='utf-8') as file:
    data300 = yaml.safe_load(file)

prompt_name = "artist1734_1"
df = pd.read_csv(f"prompt/{prompt_name}.csv")
# prompt = ["An image in the style of " + p for p in df.prompt]
prompt = set(df.prompt)
new_artist_set = set(data300["artist_list2"] + data600["artist_list2"])
old_artist_list = list(prompt - new_artist_set)
new_artist_list = list(new_artist_set)
prompt = ["An image in the style of " + p for p in new_artist_list]

prev_prompt = prompt[:10]
new_prompt = ["art"] * 10
retain_prompt = prompt[10:]

lamb = 0.5
erase_scale = 1
preserve_scale = 0.1
with_key = True

seed = [random.randint(0, 5000) for _ in prompt]

prompt_count = 20
sample_count = 5

config = {
    "version": version,
    
    "prev_prompt": prev_prompt,
    "new_prompt": new_prompt,
    "retain_prompt": retain_prompt,
    
    "lamb": lamb,
    "erase_scale": erase_scale,
    "preserve_scale": preserve_scale, 
    "with_key": with_key,

    "seed": seed,

    "prompt_count": prompt_count,
    "sample_count": sample_count
}

with open(f"data/{name}.yaml", 'w') as file:
    yaml.dump(config, file)

#
@torch.no_grad()
def erase_unet(name):

    with open(f"data/{name}.yaml", 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)
    
    version = config["version"]
    prev_prompt = config["prev_prompt"]
    new_prompt = config["new_prompt"]
    retain_prompt = config["retain_prompt"]

    lamb = config["lamb"]
    erase_scale = config["erase_scale"]
    preserve_scale = config["preserve_scale"]
    with_key = config["with_key"]

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    unet = UNet2DConditionModel.from_pretrained(version, subfolder="unet").to(device)
    tokenizer = CLIPTokenizer.from_pretrained(version, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(version, subfolder="text_encoder").to(device)

    ca_layer = []
    for n, module in unet.named_modules():
        if n[-5:] != "attn2": continue
        ca_layer.append(module)

    value_layer = [layer.to_v for layer in ca_layer]
    target_layer = value_layer

    if with_key:
        key_layer = [layer.to_k for layer in ca_layer]
        target_layer += key_layer
    
    prev_token = tokenizer(prev_prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt").input_ids.to(device)
    prev_embd = text_encoder(prev_token)[0].permute(0, 2, 1)
    
    new_token = tokenizer(new_prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt").input_ids.to(device)
    new_embd = text_encoder(new_token)[0].permute(0, 2, 1)

    m2 = (prev_embd @ prev_embd.permute(0, 2, 1)).sum(0) * erase_scale
    m2 += lamb * torch.eye(m2.shape[0], device=device)

    m3 = (new_embd @ prev_embd.permute(0, 2, 1)).sum(0) * erase_scale
    m3 += lamb * torch.eye(m3.shape[0], device=device)

    if retain_prompt:

        retain_token = tokenizer(retain_prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt").input_ids.to(device)
        retain_embd = text_encoder(retain_token)[0].permute(0, 2, 1)

        m2 += (retain_embd @ retain_embd.permute(0, 2, 1)).sum(0) * preserve_scale
        m3 += (retain_embd @ retain_embd.permute(0, 2, 1)).sum(0) * preserve_scale

    for layer in target_layer:
        m1 = layer.weight @ m3
        layer.weight = torch.nn.Parameter((m1 @ torch.inverse(m2)).detach())

    torch.save(unet.state_dict(), f"model/{name}.pth")

erase_unet(name)

In [51]:
new_artist_list

['Alexandre Calame',
 'Meryl McMaster',
 'Jun Kaneko',
 'Harry Clarke',
 'James Tissot',
 'Alex Garant',
 'Shepard Fairey',
 'Jim Mahfood',
 'Auguste Herbin',
 'Wassily Kandinsky',
 'Julie Mehretu',
 'David Bowie',
 'Zhichao Cai',
 'Albert Edelfelt',
 'Shintaro Kago',
 'László Moholy-Nagy',
 'Tadao Ando',
 'Vincent Van Gogh',
 'Anne-Louis Girodet',
 'Richard Corben',
 'Steve Henderson',
 'Rodríguez ARS',
 'Frederick McCubbin',
 'Thomas Dodd',
 'Don Bluth',
 'Tomer Hanuka',
 'Raja Ravi Varma',
 'Edwin Henry Landseer',
 'Jan van Kessel the Elder',
 'Archibald Thorburn',
 'Alexej von Jawlensky',
 'Marina Abramović',
 'Augustus Edwin Mulready',
 'Charles E. Burchfield',
 'Romero Britto',
 'Martin Ansin',
 'Martiros Saryan',
 'Jamini Roy',
 'Walter Langley',
 'Gerhard Munthe',
 'Wendy Froud',
 'Albert Goodwin',
 'Peter Doig',
 'Juan Gris',
 'Charles Addams',
 'Andreas Vesalius',
 'Ilya Repin',
 'Eric Fischl',
 'Anna and Elena Balbusso',
 'Yves Tanguy',
 'Larry Sultan',
 'Catrin Welz-Stein',

In [49]:
(0.288+0.265+0.281+0.243+0.312+0.306+0.279+0.303+0.291+0.255)

2.823

In [None]:
import random
import warnings

warnings.filterwarnings("ignore")

import yaml

import pandas as pd

import torch
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

#
name = "default"
version = "stabilityai/stable-diffusion-2-1-base"

prompt_name = "artist1734_1"
df = pd.read_csv(f"prompt/{prompt_name}.csv")
prompt = ["An image in the style of " + p for p in df.prompt]
prev_prompt = prompt[:10]
new_prompt = ["art"] * 10
retain_prompt = prompt[10:]

lamb = 0.5
erase_scale = 1
preserve_scale = 0.1
with_key = True

seed = [random.randint(0, 5000) for _ in prompt]

prompt_count = 20
sample_count = 5

config = {
    "version": version,
    
    "prev_prompt": prev_prompt,
    "new_prompt": new_prompt,
    "retain_prompt": retain_prompt,
    
    "lamb": lamb,
    "erase_scale": erase_scale,
    "preserve_scale": preserve_scale, 
    "with_key": with_key,

    "seed": seed,

    "prompt_count": prompt_count,
    "sample_count": sample_count
}

with open(f"data/{name}.yaml", 'w') as file:
    yaml.dump(config, file)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
unet = UNet2DConditionModel.from_pretrained(version, subfolder="unet").to(device)
torch.save(unet.state_dict(), f"model/{name}.pth")