<a href="https://colab.research.google.com/github/Con6924/SPM/blob/main/demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Create Configs

### Step 1: Choose Base Model

Examples of `pretrained_sd_model`:
- SD v1.4: `CompVis/stable-diffusion-v1-4`
- SD v1.5: `runwayml/stable-diffusion-v1-5`
- Dreamshaper v8: `stablediffusionapi/dreamshaper-v8`
- SD v2.1: `stabilityai/stable-diffusion-2-1-base`
- WD1.5 beta3: `Birchlabs/wd-1-5-beta3-unofficial`
- SDXL: `stabilityai/stable-diffusion-xl-base-1.0`
- Juggernaut v6: `RunDiffusion/Juggernaut-XL-v6`
- PonyXL: `stablediffusionapi/pony-diffusion-v6-xl`

If the base model is v2.x, set `is_v2_model` to `true`.

In [42]:
# Base model the SPM is trained against. Step 6 writes this value to
# `pretrained_model.name_or_path`; here it points at a local .safetensors file.
pretrained_sd_model = r"/workspace/SPM/URPMXL-V6.safetensors"  #@param {type: "string"}
# Both flags are kept as the strings "true"/"false" because Step 6 interpolates
# them verbatim into the YAML config (`v2` and `v_pred`).
is_v2_model = "false" #@param ["true", "false"]
is_v_prediction_model = "false" #@param ["true", "false"]

### Step 2: Choose Concept

- `target_concept`: Targeted concept for erasing
- `surrogate_concept`: Surrogate concept for defining model generation after erasure, empty string by default

In [43]:
# If you want to train multiple concepts at once in a single SPM, set this to True
multiple_concepts = True #@param {type: "boolean"}

# Single-concept fields: only read by Step 6 when multiple_concepts is False.
target_concept = ''  #@param {type: "string"}
surrogate_concept = ''  #@param {type: "string"}

# If multiple_concepts is True, the following fields will be used instead.
# The two lists are paired by position (target_concepts[i] is replaced by
# surrogate_concepts[i]), so they must have the same length.
# They are declared as "raw" form fields: a "string" field would re-quote the
# value on edit and turn each list into a plain string.
target_concepts = ['teen','child','kid','loli','toddler','preteen','baby','newborn','girl','boy','young girl','young boy','year old','years old','little boy','little girl']  #@param {type: "raw"}
surrogate_concepts = ['adult','adult','adult','woman','adult','adult','adult','adult','woman','man','woman','man','20 year old','20 years old','man','woman']  #@param {type: "raw"}

### Step 3: SPM Settings

- `mode`: `erase_with_la` or `erase` (for ablation)
- `dim`: default to 1, other values for ablation
- `sampling_batch_size`: indicates how many latent anchors are sampled for each iteration, default to 4, can be reduced if there's not enough VRAM
- `la_strength`: the strength of the latent anchoring loss, which balances erasure against preservation; defaults to 1000 for SD v1.4 models and should be tuned further for other base models for better performance


In [75]:
# SPM settings consumed by Step 6 when it generates the config files.
# `mode` becomes the `action` of every prompt entry.
mode = 'erase_with_la' #@param ["erase_with_la", "erase"]
# `dim` is written to `network.rank`.
dim = 2 #@param {type: "number"}
# `sampling_batch_size` and `la_strength` are copied into every prompt entry.
sampling_batch_size = 1 #@param {type: "number"}
la_strength = 400 #@param {type: "number"}
# `erasing_scale` is written as the `guidance_scale` of every prompt entry.
erasing_scale = 2.0 #@param {type: "number"}

### Step 4: Training Settings

There are two parts for training settings, SD settings and optimization settings.
Notice that `resolution` is set to 512 for SD v1.x, 768 for SD v2.x, and 1024 for SDXL.

In [76]:
# SD settings. `resolution` and `dynamic_resolution` are copied into every
# prompt entry; `max_denoising_steps` goes to `train.max_denoising_steps` and
# `clip_skip` to `pretrained_model.clip_skip` (see Step 6).
resolution = 1024 #@param [512, 640, 768, 896, 960, 1024]
max_denoising_steps = 20  #@param {type: "number"}
# Kept as a "true"/"false" string because it is interpolated verbatim into YAML.
dynamic_resolution = "true" #@param ["true", "false"]
clip_skip = 2 #@param [1, 2]

# Optimization settings. `batch_size` is written both to `train.batch_size`
# and to every prompt entry.
batch_size = 1  #@param {type: "number"}
iterations = 5000  #@param {type: "number"}
# `lr` is used for `train.lr` and `train.unet_lr`; `train.text_encoder_lr`
# is set to half of it.
lr = 15e-5  #@param {type: "number"}
# Written to `train.optimizer_type` and `train.lr_scheduler` respectively.
optimizer = "AdamW8bit" #@param {type: "string"}
lr_scheduler = "cosine_with_restarts" #@param {type: "string"}
# NOTE(review): confirm these two values actually reach the generated config
# (`train.lr_warmup_steps` / `train.lr_scheduler_num_cycles`) - TODO verify in Step 6.
lr_warmup_steps = 300 #@param {type: "number"}
lr_scheduler_num_cycles = 3 #@param {type: "number"}
# Written to `save.per_steps`.
save_per_steps = 750  #@param {type: "number"}
# Used for both `train.precision` and `save.precision`.
precision = "bfloat16" #@param ["float32", "float16", "bfloat16"]
# Written to `logging.verbose` as a "true"/"false" string.
verbose = "false" #@param ["true", "false"]

### Step 5: (Optional) Tracking Training Details with WandB

You can set up your wandb token to track the training details, including training statistics (e.g. losses, learning rates) and visualizations.
Your wandb token can be retrieved from https://wandb.ai/authorize .

In [77]:
wandb_token = "" #@param {type: "string"}

prompts_to_visualize = ["cat", "dog"]  #@param {type: "string"}
generate_num = 2  #@param {type: "number"}

## track target & surrogate by default
prompts_to_visualize = [target_concept, surrogate_concept] + prompts_to_visualize

## login with your wandb token
if wandb_token != "": 
    !wandb login {wandb_token}


### Step 6: Generate Config Files

Run the following code block and the config files are automatically generated.

In [78]:
# You can customize these strings to distinguish your different experiments.
exp_name = "sdxl_child_removal_v2"
save_name = f"{exp_name}"
run_name = f"{exp_name}"

# Both files are written under configs/<save_name>/ relative to the working directory.
config_file_path = f"configs/{save_name}/config.yaml"
prompts_file_path = f"configs/{save_name}/prompt.yaml"

# WandB logging follows Step 5: it is enabled only when a token was supplied.
use_wandb = "true" if wandb_token else "false"

# Main training config. Every `{...}` placeholder is filled from the form
# fields of Steps 1-5, including the LR warmup and restart-cycle settings.
config_file_content = f"""
prompts_file: "{prompts_file_path}"

pretrained_model:
  name_or_path: "{pretrained_sd_model}"
  v2: {is_v2_model}
  v_pred: {is_v_prediction_model}
  clip_skip: {clip_skip}

network:
  rank: {dim}
  alpha: 1.0

train:
  precision: {precision}
  noise_scheduler: "ddim"
  iterations: {iterations}
  batch_size: {batch_size}
  lr: {lr}
  unet_lr: {lr}
  text_encoder_lr: {0.5 * lr}
  optimizer_type: "{optimizer}"
  lr_scheduler: "{lr_scheduler}"
  lr_warmup_steps: {lr_warmup_steps}
  lr_scheduler_num_cycles: {lr_scheduler_num_cycles}
  max_denoising_steps: {max_denoising_steps}

save:
  name: "{save_name}"
  path: "output/{save_name}"
  per_steps: {save_per_steps}
  precision: {precision}

logging:
  use_wandb: {use_wandb}
  interval: 500
  seed: 0
  generate_num: {generate_num}
  run_name: "{run_name}"
  verbose: {verbose}
  prompts: {prompts_to_visualize}

other:
  use_xformers: true
"""

import os

# exist_ok makes re-running the cell safe without a separate existence check.
os.makedirs(f"./configs/{save_name}", exist_ok=True)

with open(config_file_path, "w", encoding="utf-8") as f:
  f.write(config_file_content)


def _prompt_entry(target, surrogate):
  """Return the YAML entry that erases `target` and anchors it to `surrogate`.

  The remaining fields come from the Step 3 / Step 4 settings.
  NOTE: the concepts are interpolated as-is, so a concept containing a double
  quote would produce invalid YAML.
  """
  return f"""
- target: "{target}"
  positive: "{target}"
  unconditional: ""
  neutral: "{surrogate}"
  action: "{mode}"
  guidance_scale: "{erasing_scale}"
  resolution: {resolution}
  batch_size: {batch_size}
  dynamic_resolution: {dynamic_resolution}
  la_strength: {la_strength}
  sampling_batch_size: {sampling_batch_size}
  """


if multiple_concepts:
  # zip() would silently drop the unmatched tail, so fail loudly instead of
  # training on an incomplete concept list.
  if len(target_concepts) != len(surrogate_concepts):
    raise ValueError(
      f"target_concepts has {len(target_concepts)} entries but "
      f"surrogate_concepts has {len(surrogate_concepts)}; they must be paired one-to-one"
    )
  # Dedicated loop names keep the single-concept globals `target_concept` /
  # `surrogate_concept` from being overwritten by this cell.
  prompts_file_content = "".join(
    _prompt_entry(target, surrogate)
    for target, surrogate in zip(target_concepts, surrogate_concepts)
  )
else:
  prompts_file_content = _prompt_entry(target_concept, surrogate_concept)

with open(prompts_file_path, "w", encoding="utf-8") as f:
  f.write(prompts_file_content)


## Start Training

In [None]:
# Launch the training script in a subshell; IPython substitutes
# {config_file_path} with the path of the config generated in Step 6.
!python ./train_spm_xl_mem_reduce.py --config_file {config_file_path}

/workspace/SPM/./train_spm_xl_mem_reduce.py:41: PydanticDeprecatedSince20: The `json` method is deprecated; use `model_dump_json` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/
  "prompts": ",".join([prompt.json() for prompt in prompts]),
/workspace/SPM/./train_spm_xl_mem_reduce.py:42: PydanticDeprecatedSince20: The `json` method is deprecated; use `model_dump_json` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/
  "config": config.json(),
Loading local checkpoint from /workspace/SPM/URPMXL-V6.safetensors
Fetching 17 files: 100%|█████████████████████| 17/17 [00:00<00:00, 42747.70it/s]
Loading pipeline components...:   0%|                     | 0/7 [00:00<?, ?it/s]Some weights of the model checkpoint were not used when initializing CLIPTextModel: 
 ['text_model.embeddings.position_ids']
Loading pipeline com

## Inference

### Prepare SD Pipeline

In [None]:
import torch
from diffusers import DiffusionPipeline
import copy
import gc

def flush():
  """Release memory that is no longer referenced, on the host and on the GPU.

  The garbage collector runs first so that tensors kept alive only by
  reference cycles are freed; only then can their blocks be handed back by
  emptying PyTorch's CUDA cache.
  """
  gc.collect()
  torch.cuda.empty_cache()

# Reclaim memory left over from earlier cells before loading the pipeline.
flush()

# Load SD v1.4 in half precision through the `lpw_stable_diffusion` custom
# pipeline (it provides the `text2img` call used by the cells below), then
# move it to the GPU. `local_files_only=True` is passed, so the model files
# are expected to be available locally already.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
    local_files_only=True,
).to("cuda")
pipe.enable_xformers_memory_efficient_attention()

# Independent copy of the UNet, kept as the baseline for the comparisons below.
orig_unet = copy.deepcopy(pipe.unet)


In [None]:
# Generation Configs
# Shared by the inference cells below, which pass these values to `pipe.text2img`.

# Prompt rendered in Case A (Case B defines its own list of prompts).
prompt = "mickey mouse" #@param {"type": "string"}

negative_prompt = "bad anatomy,watermark,extra digit,signature,worst quality,jpeg artifacts,normal quality,low quality,long neck,lowres,error,blurry,missing fingers,fewer digits,missing arms,text,cropped,Humpbacked ,bad hands,username" #@param {"type": "string"}
# Output image size in pixels.
width = 512 #@param {type: "number"}
height = 640 #@param {type: "number"}
# Passed as `num_inference_steps` and `guidance_scale` respectively.
steps = 20  #@param {type:"slider", min:1, max:50, step:1}
cfg_scale = 7.5 #@param {type:"slider", min:1, max:16, step:0.5}
# Number of images (and random seeds) generated per prompt in each row.
sample_cnt = 2 #@param {type:"number"}


### Case A: Generate with Single SPM Applied

Here you can set `spm_paths` to trained SPM paths to compare generation behaviours with different SPM applied on SD.

(`len(spm_paths)` + 1) × `sample_cnt` samples will be displayed: one row generated without any SPM, followed by one row per entry in `spm_paths`.

**Notice that the SPM applied here DOES NOT activate its Facilitated Transport mechanism. For the full feature set, you may use `infer_spm.py`.**

In [None]:
# SPM Comparison
#
# Renders one row of baseline images followed by one row per SPM listed in
# `spm_paths`. Every row reuses the same seeds, so columns are comparable.

import matplotlib.pyplot as plt
import random

spm_paths = [
    "output/snoopy/snoopy_last.safetensors",
    "output/pikachu/pikachu_last.safetensors",
]
random_seeds = [random.randint(0, 2**32 - 1) for _ in range(sample_cnt)]


def _sample_fixed_prompt(seeds):
    """Generate one image of `prompt` per seed with the pipeline's current UNet."""
    return [
        pipe.text2img(
            prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            generator=torch.manual_seed(seed),
        ).images[0]
        for seed in seeds
    ]


# -------
# w/o SPM

pipe.unet = orig_unet

orig_samples = _sample_fixed_prompt(random_seeds)

# -------
# w/ SPM

spms_samples = []

for spm_path in spm_paths:
    # Load each SPM into a fresh copy of the baseline UNet. Loading into
    # `pipe.unet` directly would modify `orig_unet` itself (they are the same
    # object at this point) and stack every later SPM on top of the earlier
    # ones, so the rows would no longer show a single SPM each.
    lora_unet = copy.deepcopy(orig_unet)
    pipe.unet = lora_unet
    pipe.load_lora_weights(spm_path)
    # NOTE(review): if an SPM file also carries text-encoder weights, those
    # would still accumulate across iterations - TODO confirm the file layout.

    spm_samples = _sample_fixed_prompt(random_seeds)
    spms_samples.append(spm_samples)
# ---
spm_cnt = len(spm_paths)
# squeeze=False keeps `ax` two-dimensional even when sample_cnt == 1, so the
# `ax[row, col]` indexing below always works.
fig, ax = plt.subplots(
    spm_cnt + 1,
    sample_cnt,
    figsize=(17, 2*(spm_cnt+1)),
    squeeze=False,
)

# Row 0: baseline; rows 1..spm_cnt: one SPM each.
for row, row_samples in enumerate([orig_samples] + spms_samples):
    for col, img in enumerate(row_samples):
        ax[row, col].imshow(img)
        ax[row, col].axis('off')

plt.subplots_adjust(wspace=0.1, hspace=0.1)

plt.show()



### Case B: Generate with Multiple SPMs Applied

**Notice that the SPM applied here DOES NOT activate its Facilitated Transport mechanism. For the full feature set, you may use `infer_spm.py`.**

In [None]:
# Multi SPMs Generation

import matplotlib.pyplot as plt
import random
import torch
from diffusers import DiffusionPipeline
import copy
import gc

from src.models.merge_spm import merge_to_sd_model

# The two SPMs to compare, and the per-SPM ratios handed to `merge_to_sd_model`
# when both are applied together.
spm_path_1 = "output/snoopy/snoopy_last.safetensors"
spm_path_2 = "output/pikachu/pikachu_last.safetensors"
ratios = [1.0, 1.0]

prompts = ['snoopy', 'pikachu', 'donald duck', 'woman', '']


# One seed per generated image; the same seeds are reused for every row.
random_seeds = [random.randint(0, 2**32 - 1) for _ in range(sample_cnt * len(prompts))]

# Each prompt repeated `sample_cnt` times, in prompt order, so that it lines up
# one-to-one with `random_seeds`.
repeated_prompts = [p for p in prompts for _ in range(sample_cnt)]

def flush():
  """Release memory that is no longer referenced, on the host and on the GPU.

  The garbage collector runs first so that tensors kept alive only by
  reference cycles are freed; only then can their blocks be handed back by
  emptying PyTorch's CUDA cache.
  """
  gc.collect()
  torch.cuda.empty_cache()

flush()

# One list of images per row of the final grid:
# original, spm_1 only, spm_2 only, both merged.
samples = []


def _sample_prompt_grid():
    """Generate one image per (prompt, seed) pair with the pipeline's current UNet."""
    return [
        pipe.text2img(
            grid_prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            generator=torch.manual_seed(seed),
        ).images[0]
        for grid_prompt, seed in zip(repeated_prompts, random_seeds)
    ]


# -------
# Original

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float32,
    local_files_only=True,
)

pipe = pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention()
# Untouched copy of the UNet; every SPM below starts from it.
orig_unet = copy.deepcopy(pipe.unet)

orig_samples = _sample_prompt_grid()
samples.append(orig_samples)

# -------
# Applying spm_1 / spm_2

for spm_path in [spm_path_1, spm_path_2]:
    # Load each SPM into a fresh copy of the baseline UNet; loading into the
    # UNet left over from the previous iteration would show spm_2 stacked on
    # top of spm_1 instead of spm_2 alone.
    lora_unet = copy.deepcopy(orig_unet)
    pipe.unet = lora_unet
    pipe.load_lora_weights(spm_path)

    spm_samples = _sample_prompt_grid()
    samples.append(spm_samples)

# -------
# Applying spm_1 & spm_2

# NOTE(review): the merge is applied to `orig_unet` (and the text encoder)
# directly; presumably it changes their weights in place, so `orig_unet`
# should not be treated as the baseline after this point - TODO confirm.
pipe.unet = orig_unet
merge_to_sd_model(pipe.text_encoder, pipe.unet, [spm_path_1, spm_path_2], ratios)

spm_samples = _sample_prompt_grid()
samples.append(spm_samples)

# -------
# Visualize

# One row per entry in `samples` (instead of a hard-coded row count) and one
# column per generated image. squeeze=False keeps `ax` two-dimensional for any
# grid size, so the `ax[row, col]` indexing below always works.
fig, ax = plt.subplots(
    len(samples),
    sample_cnt * len(prompts),
    figsize=(17, 2*len(samples)),
    squeeze=False,
)

def format_yaxis(ax, label):
    """Show `label` as the y-axis label of `ax` while hiding all other axis decoration.

    The axis is switched back on (the caller turns it off for every image),
    then the x-axis, the y tick labels, the tick marks and the four spines
    are hidden so that only the label text remains visible.
    """
    ax.axis('on')
    ax.set_ylabel(label)
    ax.get_xaxis().set_visible(False)
    ax.set_yticklabels([])
    ax.tick_params(length=0)
    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_visible(False)

def get_exp_name(x):
    """Return the part of path `x` between the last '/' and the last '_last.'."""
    return x[x.rfind('/') + 1: x.rfind('_last.')]


# Row labels, in the same order the rows were appended to `samples`.
ylabels = ['Original', get_exp_name(spm_path_1), get_exp_name(spm_path_2), 'Both']
for n, row_images in enumerate(samples):
    for i, img in enumerate(row_images):
        cell = ax[n, i]
        cell.imshow(img)
        cell.axis('off')
        # Column titles (prompt plus 1-based sample index) only on the first row.
        if n == 0:
            cell.set_title(f"{repeated_prompts[i]}_{i % sample_cnt + 1}", va='top', fontsize='small')
        # Row label only on the first column.
        if i == 0:
            format_yaxis(cell, ylabels[n])


plt.subplots_adjust(wspace=0.1, hspace=0.1)

plt.show()


In [None]:
# Release memory (optional)
# Drop the notebook's references to the pipeline and both UNet copies, then
# let flush() reclaim the memory they occupied.

del pipe, orig_unet, lora_unet
flush()