## 在SD21上测试PAN攻击

In [1]:
import random
import wandb
import argparse
import copy
import hashlib
import itertools
import logging
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys
repo_path = "/data/home/yekai/github/mypro/MetaCloak"
sys.path.append(repo_path)
from pathlib import Path
import datasets
import diffusers
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.utils.import_utils import is_xformers_available
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm
from transformers import AutoTokenizer, PretrainedConfig
from robust_facecloak.model.db_train import  DreamBoothDatasetFromTensor
from robust_facecloak.model.db_train import import_model_class_from_model_name_or_path
from robust_facecloak.generic.data_utils import PromptDataset, load_data
from robust_facecloak.generic.share_args import share_parse_args

import pickle
from copy import deepcopy

  from .autonotebook import tqdm as notebook_tqdm
2024-09-29 20:31:38.846624: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-29 20:31:38.846684: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-29 20:31:38.848177: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-29 20:31:38.854426: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
class myargs():
    def __init__(self):
        self.learning_rate=5e-7
        self.total_trail_num = 4
        self.instance_prompt="a photo of sks person"
        self.class_data_dir=f"{repo_path}/prior-data/SD21base/class-person"
        self.instance_data_dir_for_adversarial = f"{repo_path}/dataset/VGGFace2-clean/0/set_B"
        self.output_dir = "./tmpdata"
        self.class_prompt="a photo of a person"
        self.total_train_steps = 1000
        self.interval = 200
        self.advance_steps = 2
        self.radius = 11
        self.resolution=512
        self.center_crop=True
        self.with_prior_preservation=True
        self.revision = None
        self.prior_loss_weight = 1.0
        self.train_text_encoder = True
        self.enable_xformers_memory_efficient_attention = True
        self.mixed_precision = "bf16"
        self.attack_pgd_random_start = False
        

args = myargs()

### 首先是模型加载和训练代码

In [3]:
def train_few_step(
    args,
    models,
    tokenizer,
    noise_scheduler,
    vae,
    data_tensor: torch.Tensor,
    num_steps=20,
    step_wise_save=False,
    save_step=100, 
    retain_graph=False,
    task_loss_name = None,
    copy_flag = True
):
    # Load the tokenizer
    if copy_flag:
        unet, text_encoder = copy.deepcopy(models[0]), copy.deepcopy(models[1])
    else:
        unet, text_encoder = models[0], models[1]
    print('copy model done')
    # 绑定unet和文本编码器的参数，共同优化
    params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters())

    # 设置优化器，优化目标为unet参数和文本编码器参数
    optimizer = torch.optim.AdamW(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(0.9, 0.999),
        weight_decay=1e-2,
        eps=1e-08,
    )

    train_dataset = DreamBoothDatasetFromTensor(
        data_tensor,
        # A photo of sks person
        args.instance_prompt,
        tokenizer,
        args.class_data_dir,
        args.class_prompt,
        args.resolution,
        args.center_crop,
    )

    weight_dtype = torch.bfloat16
    device = torch.device("cuda")

    # 将关键模型移动到对应设备
    vae.to(device, dtype=weight_dtype)
    text_encoder.to(device, dtype=weight_dtype)
    unet.to(device, dtype=weight_dtype)

    
    step2modelstate={}
        
    pbar = tqdm(total=num_steps, desc="training")
    print("Start training...")
    for step in range(num_steps):
        print(f"step: {step}/{num_steps}")
        # print(calculate_model_hash(text_encoder))
        # 根据设置选择是否保存训练中间过程参数
        if step_wise_save and ((step+1) % save_step == 0 or step == 0):
            # make sure the model state dict is put to cpu
            step2modelstate[step] = {
                "unet": copy.deepcopy(unet.cpu().state_dict()),
                "text_encoder": copy.deepcopy(text_encoder.cpu().state_dict()),
            }
            # move the model back to gpu
            unet.to(device, dtype=weight_dtype); text_encoder.to(device, dtype=weight_dtype)
            
        pbar.update(1)
        # 训练模式
        unet.train()
        text_encoder.train()
        # 循环从训练数据集中取一个样本
        step_data = train_dataset[step % len(train_dataset)]
        # 将样本中的类别图片和实例图片整合并移动到设备上
        # print((step_data["instance_images"]))
        # print((step_data["class_images"]))
        pixel_values = torch.stack([step_data["instance_images"].to(device), step_data["class_images"].to(device)]).to(
            device, dtype=weight_dtype
        )
        # 将样本中的类别提示词和实例提示词整合并移动到设备上
        input_ids = torch.cat([step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0).to(device)
        # 使用VAE对图像进行编码，并对潜在表示进行后处理
        latents = vae.encode(pixel_values).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

        # Sample noise that we'll add to the latents
        # 向图片编码向量（潜在空间向量表示）添加随机噪声
        noise = torch.randn_like(latents)
        # batch_size
        bsz = latents.shape[0]
        # Sample a random timestep for each image
        # 为每个图片生成一个随机step
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        # 前向过程，得到前向扩散特定时间步后的图片的潜在空间向量
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Get the text embedding for conditioning
        # 文本编码向量作为条件信息
        encoder_hidden_states = text_encoder(input_ids)[0]
        
        # Predict the noise residual
        # 模型基于当前的噪声潜在表示（noisy_latents）、时间步（timesteps）和文本条件（encoder_hidden_states），预测噪声残差
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

        # Get the target for loss depending on the prediction type
        # 预测的可以是噪声，也可以是变化速度
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        # with prior preservation loss
        # 可选是否使用先验保留损失
        if args.with_prior_preservation:
            # 再次分为一半一半，对应之前的stack操作
            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
            target, target_prior = torch.chunk(target, 2, dim=0)

            # Compute instance loss
            instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            # Compute prior loss  确保在原来类别上的生成能力不丢失
            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

            # Add the prior loss to the instance loss.
            loss = instance_loss + args.prior_loss_weight * prior_loss

        else:
            # 不使用先验保留损失
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        if task_loss_name is not None:
            wandb.log({f"{task_loss_name}": loss.item()})
        # 反向传播
        loss.backward(retain_graph=retain_graph)
        # 梯度裁剪
        torch.nn.utils.clip_grad_norm_(params_to_optimize, 1.0, error_if_nonfinite=True)
        # 参数优化
        optimizer.step()
        optimizer.zero_grad()

    pbar.close()
    # 返回训练的参数数据
    if step_wise_save:
        return [unet, text_encoder], step2modelstate
    else:     
        return [unet, text_encoder]

# 主要模型的加载
def load_model(args, model_path):
    print(model_path)
    # import correct text encoder class
    text_encoder_cls = import_model_class_from_model_name_or_path(model_path, args.revision)

    # Load scheduler and models
    # 文本编码器加载
    text_encoder = text_encoder_cls.from_pretrained(
        model_path,
        subfolder="text_encoder",
        revision=args.revision,
    )
    # unet加载
    unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet", revision=args.revision)
    # tokenizer加载
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        subfolder="tokenizer",
        revision=args.revision,
        use_fast=False,
    )
    # 使用DDPM同款调度器
    noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
    # 加载预训练的vae，vae不需要更新参数
    vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", revision=args.revision)

    vae.requires_grad_(False)

    # 甚至可以不更新文本编码器的参数
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    if args.enable_xformers_memory_efficient_attention:
        print("You selected to used efficient xformers")
        print("Make sure to install the following packages before continue")
        print("pip install triton==2.0.0.dev20221031")
        print("pip install pip install xformers==0.0.17.dev461")

        unet.enable_xformers_memory_efficient_attention()
    # 返回5个关键模型
    return text_encoder, unet, tokenizer, noise_scheduler, vae

def save_image(perturbed_data, id_stamp):
    save_folder = f"{args.output_dir}/noise-ckpt/{id_stamp}"
    os.makedirs(save_folder, exist_ok=True)
    noised_imgs = perturbed_data.detach()
    img_names = [
        str(instance_path).split("/")[-1]
        for instance_path in list(Path(args.instance_data_dir_for_adversarial).iterdir())
    ]
    for img_pixel, img_name in zip(noised_imgs, img_names):
        save_path = os.path.join(save_folder, f"noisy_{img_name}")
        Image.fromarray(
            img_pixel.float().detach().cpu().permute(1, 2, 0).numpy().squeeze().astype(np.uint8)
        ).save(save_path)

#### 加载模型

In [4]:
model_paths = ["/data/home/yekai/github/mypro/MetaCloak/SD/stable-diffusion-v1-5"]
num_models = len(model_paths)

In [5]:
MODEL_BANKS = [load_model(args, path) for path in model_paths]
MODEL_STATEDICTS = [
    {
        "text_encoder": MODEL_BANKS[i][0].state_dict(),
        "unet": MODEL_BANKS[i][1].state_dict(),
    }
    for i in range(num_models)
]

You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.


/data/home/yekai/github/mypro/MetaCloak/SD/stable-diffusion-v1-5
model_path:/data/home/yekai/github/mypro/MetaCloak/SD/stable-diffusion-v1-5
You selected to used efficient xformers
Make sure to install the following packages before continue
pip install triton==2.0.0.dev20221031
pip install pip install xformers==0.0.17.dev461


#### 开始初始的预训练

In [6]:
# 加载原始扰动数据
perturbed_data = load_data(
    args.instance_data_dir_for_adversarial,
    # size=args.resolution,
    # center_crop=args.center_crop,
)
original_data= copy.deepcopy(perturbed_data)


init_model_state_pool = {}

In [7]:
pbar = tqdm(total=num_models, desc="initializing models")
# split sub-models
# 对于每一个模型，都进行一次训练
for j in range(num_models):
    init_model_state_pool[j] = {}
    # 提取关键模块
    text_encoder, unet, tokenizer, noise_scheduler, vae = MODEL_BANKS[j]
    
    # 加载unet和text_encoder的模型参数
    unet.load_state_dict(MODEL_STATEDICTS[j]["unet"])
    text_encoder.load_state_dict(MODEL_STATEDICTS[j]["text_encoder"])
    # 打包unet和text_encoder
    f_ori = [unet, text_encoder]
    # 得到训练total_train_steps步之后的unet, text_encoder参数以及中间状态参数
    print("start training model", j)
    f_ori, step2state_dict = train_few_step(
            args,
            f_ori,
            tokenizer,
            noise_scheduler,
            vae,
            perturbed_data.float(),
            args.total_train_steps,
            step_wise_save=True,
            save_step=args.interval,
            task_loss_name=None,
    )  
    # init_model_state_pool就来保存训练中间状态参数
    init_model_state_pool[j] = step2state_dict

    # 释放占用的资源
    del f_ori, unet, text_encoder, tokenizer, noise_scheduler, vae
    import gc
    gc.collect()
    torch.cuda.empty_cache()
    pbar.update(1)
pbar.close()

initializing models:   0%|          | 0/1 [00:00<?, ?it/s]

start training model 0
copy model done




Start training...
step: 0/1000




step: 1/1000




step: 2/1000




step: 3/1000




step: 4/1000




step: 5/1000




step: 6/1000




step: 7/1000




step: 8/1000




step: 9/1000




step: 10/1000




step: 11/1000




step: 12/1000




step: 13/1000




step: 14/1000




step: 15/1000




step: 16/1000




step: 17/1000




step: 18/1000




step: 19/1000




step: 20/1000




step: 21/1000




step: 22/1000




step: 23/1000




step: 24/1000




step: 25/1000




step: 26/1000




step: 27/1000




step: 28/1000




step: 29/1000




step: 30/1000




step: 31/1000




step: 32/1000




step: 33/1000




step: 34/1000




step: 35/1000




step: 36/1000




step: 37/1000




step: 38/1000




step: 39/1000




step: 40/1000




step: 41/1000




step: 42/1000




step: 43/1000




step: 44/1000




step: 45/1000




step: 46/1000




step: 47/1000




step: 48/1000




step: 49/1000




step: 50/1000




step: 51/1000




step: 52/1000




step: 53/1000




step: 54/1000




step: 55/1000




step: 56/1000




step: 57/1000




step: 58/1000




step: 59/1000




step: 60/1000




step: 61/1000




step: 62/1000




step: 63/1000




step: 64/1000




step: 65/1000




step: 66/1000




step: 67/1000




step: 68/1000




step: 69/1000




step: 70/1000




step: 71/1000




step: 72/1000




step: 73/1000




step: 74/1000




step: 75/1000




step: 76/1000




step: 77/1000




step: 78/1000




step: 79/1000




step: 80/1000




step: 81/1000




step: 82/1000




step: 83/1000




step: 84/1000




step: 85/1000




step: 86/1000




step: 87/1000




step: 88/1000




step: 89/1000




step: 90/1000




step: 91/1000




step: 92/1000




step: 93/1000




step: 94/1000




step: 95/1000




step: 96/1000




step: 97/1000




step: 98/1000




step: 99/1000




step: 100/1000




step: 101/1000




step: 102/1000




step: 103/1000




step: 104/1000




step: 105/1000




step: 106/1000




step: 107/1000




step: 108/1000




step: 109/1000




step: 110/1000




step: 111/1000




step: 112/1000




step: 113/1000




step: 114/1000




step: 115/1000




step: 116/1000




step: 117/1000




step: 118/1000




step: 119/1000




step: 120/1000




step: 121/1000




step: 122/1000




step: 123/1000




step: 124/1000




step: 125/1000




step: 126/1000




step: 127/1000




step: 128/1000




step: 129/1000




step: 130/1000




step: 131/1000




step: 132/1000




step: 133/1000




step: 134/1000




step: 135/1000




step: 136/1000




step: 137/1000




step: 138/1000




step: 139/1000




step: 140/1000




step: 141/1000




step: 142/1000




step: 143/1000




step: 144/1000




step: 145/1000




step: 146/1000




step: 147/1000




step: 148/1000




step: 149/1000




step: 150/1000




step: 151/1000




step: 152/1000




step: 153/1000




step: 154/1000




step: 155/1000




step: 156/1000




step: 157/1000




step: 158/1000




step: 159/1000




step: 160/1000




step: 161/1000




step: 162/1000




step: 163/1000




step: 164/1000




step: 165/1000




step: 166/1000




step: 167/1000




step: 168/1000




step: 169/1000




step: 170/1000




step: 171/1000




step: 172/1000




step: 173/1000




step: 174/1000




step: 175/1000




step: 176/1000




step: 177/1000




step: 178/1000




step: 179/1000




step: 180/1000




step: 181/1000




step: 182/1000




step: 183/1000




step: 184/1000




step: 185/1000




step: 186/1000




step: 187/1000




step: 188/1000




step: 189/1000




step: 190/1000




step: 191/1000




step: 192/1000




step: 193/1000




step: 194/1000




step: 195/1000




step: 196/1000




step: 197/1000




step: 198/1000
step: 199/1000




step: 200/1000




step: 201/1000




step: 202/1000




step: 203/1000




step: 204/1000




step: 205/1000




step: 206/1000




step: 207/1000




step: 208/1000




step: 209/1000




step: 210/1000




step: 211/1000




step: 212/1000




step: 213/1000




step: 214/1000




step: 215/1000




step: 216/1000




step: 217/1000




step: 218/1000




step: 219/1000




step: 220/1000




step: 221/1000




step: 222/1000




step: 223/1000




step: 224/1000




step: 225/1000




step: 226/1000




step: 227/1000




step: 228/1000




step: 229/1000




step: 230/1000




step: 231/1000




step: 232/1000




step: 233/1000




step: 234/1000




step: 235/1000




step: 236/1000




step: 237/1000




step: 238/1000




step: 239/1000




step: 240/1000




step: 241/1000




step: 242/1000




step: 243/1000




step: 244/1000




step: 245/1000




step: 246/1000




step: 247/1000




step: 248/1000




step: 249/1000




step: 250/1000




step: 251/1000




step: 252/1000




step: 253/1000




step: 254/1000




step: 255/1000




step: 256/1000




step: 257/1000




step: 258/1000




step: 259/1000




step: 260/1000




step: 261/1000




step: 262/1000




step: 263/1000




step: 264/1000




step: 265/1000




step: 266/1000




step: 267/1000




step: 268/1000




step: 269/1000




step: 270/1000




step: 271/1000




step: 272/1000




step: 273/1000




step: 274/1000




step: 275/1000




step: 276/1000




step: 277/1000




step: 278/1000




step: 279/1000




step: 280/1000




step: 281/1000




step: 282/1000




step: 283/1000




step: 284/1000




step: 285/1000




step: 286/1000




step: 287/1000




step: 288/1000




step: 289/1000




step: 290/1000




step: 291/1000




step: 292/1000




step: 293/1000




step: 294/1000




step: 295/1000




step: 296/1000




step: 297/1000




step: 298/1000




step: 299/1000




step: 300/1000




step: 301/1000




step: 302/1000




step: 303/1000




step: 304/1000




step: 305/1000




step: 306/1000




step: 307/1000




step: 308/1000




step: 309/1000




step: 310/1000




step: 311/1000




step: 312/1000




step: 313/1000




step: 314/1000




step: 315/1000




step: 316/1000




step: 317/1000




step: 318/1000




step: 319/1000




step: 320/1000




step: 321/1000




step: 322/1000




step: 323/1000




step: 324/1000




step: 325/1000




step: 326/1000




step: 327/1000




step: 328/1000




step: 329/1000




step: 330/1000




step: 331/1000




step: 332/1000




step: 333/1000




step: 334/1000




step: 335/1000




step: 336/1000




step: 337/1000




step: 338/1000




step: 339/1000




step: 340/1000




step: 341/1000




step: 342/1000




step: 343/1000




step: 344/1000




step: 345/1000




step: 346/1000




step: 347/1000




step: 348/1000




step: 349/1000




step: 350/1000




step: 351/1000




step: 352/1000




step: 353/1000




step: 354/1000




step: 355/1000




step: 356/1000




step: 357/1000




step: 358/1000




step: 359/1000




step: 360/1000




step: 361/1000




step: 362/1000




step: 363/1000




step: 364/1000




step: 365/1000




step: 366/1000




step: 367/1000




step: 368/1000




step: 369/1000




step: 370/1000




step: 371/1000




step: 372/1000




step: 373/1000




step: 374/1000




step: 375/1000




step: 376/1000




step: 377/1000




step: 378/1000




step: 379/1000




step: 380/1000




step: 381/1000




step: 382/1000




step: 383/1000




step: 384/1000




step: 385/1000




step: 386/1000




step: 387/1000




step: 388/1000




step: 389/1000




step: 390/1000




step: 391/1000




step: 392/1000




step: 393/1000




step: 394/1000




step: 395/1000




step: 396/1000




step: 397/1000




step: 398/1000
step: 399/1000




step: 400/1000




step: 401/1000




step: 402/1000




step: 403/1000




step: 404/1000




step: 405/1000




step: 406/1000




step: 407/1000




step: 408/1000




step: 409/1000




step: 410/1000




step: 411/1000




step: 412/1000




step: 413/1000




step: 414/1000




step: 415/1000




step: 416/1000




step: 417/1000




step: 418/1000




step: 419/1000




step: 420/1000




step: 421/1000




step: 422/1000




step: 423/1000




step: 424/1000




step: 425/1000




step: 426/1000




step: 427/1000




step: 428/1000




step: 429/1000




step: 430/1000




step: 431/1000




step: 432/1000




step: 433/1000




step: 434/1000




step: 435/1000




step: 436/1000




step: 437/1000




step: 438/1000




step: 439/1000




step: 440/1000




step: 441/1000




step: 442/1000




step: 443/1000




step: 444/1000




step: 445/1000




step: 446/1000




step: 447/1000




step: 448/1000




step: 449/1000




step: 450/1000




step: 451/1000




step: 452/1000




step: 453/1000




step: 454/1000




step: 455/1000




step: 456/1000




step: 457/1000




step: 458/1000




step: 459/1000




step: 460/1000




step: 461/1000




step: 462/1000




step: 463/1000




step: 464/1000




step: 465/1000




step: 466/1000




step: 467/1000




step: 468/1000




step: 469/1000




step: 470/1000




step: 471/1000




step: 472/1000




step: 473/1000




step: 474/1000




step: 475/1000




step: 476/1000




step: 477/1000




step: 478/1000




step: 479/1000




step: 480/1000




step: 481/1000




step: 482/1000




step: 483/1000




step: 484/1000




step: 485/1000




step: 486/1000




step: 487/1000




step: 488/1000




step: 489/1000




step: 490/1000




step: 491/1000




step: 492/1000




step: 493/1000




step: 494/1000




step: 495/1000




step: 496/1000




step: 497/1000




step: 498/1000




step: 499/1000




step: 500/1000




step: 501/1000




step: 502/1000




step: 503/1000




step: 504/1000




step: 505/1000




step: 506/1000




step: 507/1000




step: 508/1000




step: 509/1000




step: 510/1000




step: 511/1000




step: 512/1000




step: 513/1000




step: 514/1000




step: 515/1000




step: 516/1000




step: 517/1000




step: 518/1000




step: 519/1000




step: 520/1000




step: 521/1000




step: 522/1000




step: 523/1000




step: 524/1000




step: 525/1000




step: 526/1000




step: 527/1000




step: 528/1000




step: 529/1000




step: 530/1000




step: 531/1000




step: 532/1000




step: 533/1000




step: 534/1000




step: 535/1000




step: 536/1000




step: 537/1000




step: 538/1000




step: 539/1000




step: 540/1000




step: 541/1000




step: 542/1000




step: 543/1000




step: 544/1000




step: 545/1000




step: 546/1000




step: 547/1000




step: 548/1000




step: 549/1000




step: 550/1000




step: 551/1000




step: 552/1000




step: 553/1000




step: 554/1000




step: 555/1000




step: 556/1000




step: 557/1000




step: 558/1000




step: 559/1000




step: 560/1000




step: 561/1000




step: 562/1000




step: 563/1000




step: 564/1000




step: 565/1000




step: 566/1000




step: 567/1000




step: 568/1000




step: 569/1000




step: 570/1000




step: 571/1000




step: 572/1000




step: 573/1000




step: 574/1000




step: 575/1000




step: 576/1000




step: 577/1000




step: 578/1000




step: 579/1000




step: 580/1000




step: 581/1000




step: 582/1000




step: 583/1000




step: 584/1000




step: 585/1000




step: 586/1000




step: 587/1000




step: 588/1000




step: 589/1000




step: 590/1000




step: 591/1000




step: 592/1000




step: 593/1000




step: 594/1000




step: 595/1000




step: 596/1000




step: 597/1000




step: 598/1000
step: 599/1000




step: 600/1000




step: 601/1000




step: 602/1000




step: 603/1000




step: 604/1000




step: 605/1000




step: 606/1000




step: 607/1000




step: 608/1000




step: 609/1000




step: 610/1000




step: 611/1000




step: 612/1000




step: 613/1000




step: 614/1000




step: 615/1000




step: 616/1000




step: 617/1000




step: 618/1000




step: 619/1000




step: 620/1000




step: 621/1000




step: 622/1000




step: 623/1000




step: 624/1000




step: 625/1000




step: 626/1000




step: 627/1000




step: 628/1000




step: 629/1000




step: 630/1000




step: 631/1000




step: 632/1000




step: 633/1000




step: 634/1000




step: 635/1000




step: 636/1000




step: 637/1000




step: 638/1000




step: 639/1000




step: 640/1000




step: 641/1000




step: 642/1000




step: 643/1000




step: 644/1000




step: 645/1000




step: 646/1000




step: 647/1000




step: 648/1000




step: 649/1000




step: 650/1000




step: 651/1000




step: 652/1000




step: 653/1000




step: 654/1000




step: 655/1000




step: 656/1000




step: 657/1000




step: 658/1000




step: 659/1000




step: 660/1000




step: 661/1000




step: 662/1000




step: 663/1000




step: 664/1000




step: 665/1000




step: 666/1000




step: 667/1000




step: 668/1000




step: 669/1000




step: 670/1000




step: 671/1000




step: 672/1000




step: 673/1000




step: 674/1000




step: 675/1000




step: 676/1000




step: 677/1000




step: 678/1000




step: 679/1000




step: 680/1000




step: 681/1000




step: 682/1000




step: 683/1000




step: 684/1000




step: 685/1000




step: 686/1000




step: 687/1000




step: 688/1000




step: 689/1000




step: 690/1000




step: 691/1000




step: 692/1000




step: 693/1000




step: 694/1000




step: 695/1000




step: 696/1000




step: 697/1000




step: 698/1000




step: 699/1000




step: 700/1000




step: 701/1000




step: 702/1000




step: 703/1000




step: 704/1000




step: 705/1000




step: 706/1000




step: 707/1000




step: 708/1000




step: 709/1000




step: 710/1000




step: 711/1000




step: 712/1000




step: 713/1000




step: 714/1000




step: 715/1000




step: 716/1000




step: 717/1000




step: 718/1000




step: 719/1000




step: 720/1000




step: 721/1000




step: 722/1000




step: 723/1000




step: 724/1000




step: 725/1000




step: 726/1000




step: 727/1000




step: 728/1000




step: 729/1000




step: 730/1000




step: 731/1000




step: 732/1000




step: 733/1000




step: 734/1000




step: 735/1000




step: 736/1000




step: 737/1000




step: 738/1000




step: 739/1000




step: 740/1000




step: 741/1000




step: 742/1000




step: 743/1000




step: 744/1000




step: 745/1000




step: 746/1000




step: 747/1000




step: 748/1000




step: 749/1000




step: 750/1000




step: 751/1000




step: 752/1000




step: 753/1000




step: 754/1000




step: 755/1000




step: 756/1000




step: 757/1000




step: 758/1000




step: 759/1000




step: 760/1000




step: 761/1000




step: 762/1000




step: 763/1000




step: 764/1000




step: 765/1000




step: 766/1000




step: 767/1000




step: 768/1000




step: 769/1000




step: 770/1000




step: 771/1000




step: 772/1000




step: 773/1000




step: 774/1000




step: 775/1000




step: 776/1000




step: 777/1000




step: 778/1000




step: 779/1000




step: 780/1000




step: 781/1000




step: 782/1000




step: 783/1000




step: 784/1000




step: 785/1000




step: 786/1000




step: 787/1000




step: 788/1000




step: 789/1000




step: 790/1000




step: 791/1000




step: 792/1000




step: 793/1000




step: 794/1000




step: 795/1000




step: 796/1000




step: 797/1000




step: 798/1000
step: 799/1000




step: 800/1000




step: 801/1000




step: 802/1000




step: 803/1000




step: 804/1000




step: 805/1000




step: 806/1000




step: 807/1000




step: 808/1000




step: 809/1000




step: 810/1000




step: 811/1000




step: 812/1000




step: 813/1000




step: 814/1000




step: 815/1000




step: 816/1000




step: 817/1000




step: 818/1000




step: 819/1000




step: 820/1000




step: 821/1000




step: 822/1000




step: 823/1000




step: 824/1000




step: 825/1000




step: 826/1000




step: 827/1000




step: 828/1000




step: 829/1000




step: 830/1000




step: 831/1000




step: 832/1000




step: 833/1000




step: 834/1000




step: 835/1000




step: 836/1000




step: 837/1000




step: 838/1000




step: 839/1000




step: 840/1000




step: 841/1000




step: 842/1000




step: 843/1000




step: 844/1000




step: 845/1000




step: 846/1000




step: 847/1000




step: 848/1000




step: 849/1000




step: 850/1000




step: 851/1000




step: 852/1000




step: 853/1000




step: 854/1000




step: 855/1000




step: 856/1000




step: 857/1000




step: 858/1000




step: 859/1000




step: 860/1000




step: 861/1000




step: 862/1000




step: 863/1000




step: 864/1000




step: 865/1000




step: 866/1000




step: 867/1000




step: 868/1000




step: 869/1000




step: 870/1000




step: 871/1000




step: 872/1000




step: 873/1000




step: 874/1000




step: 875/1000




step: 876/1000




step: 877/1000




step: 878/1000




step: 879/1000




step: 880/1000




step: 881/1000




step: 882/1000




step: 883/1000




step: 884/1000




step: 885/1000




step: 886/1000




step: 887/1000




step: 888/1000




step: 889/1000




step: 890/1000




step: 891/1000




step: 892/1000




step: 893/1000




step: 894/1000




step: 895/1000




step: 896/1000




step: 897/1000




step: 898/1000




step: 899/1000




step: 900/1000




step: 901/1000




step: 902/1000




step: 903/1000




step: 904/1000




step: 905/1000




step: 906/1000




step: 907/1000




step: 908/1000




step: 909/1000




step: 910/1000




step: 911/1000




step: 912/1000




step: 913/1000




step: 914/1000




step: 915/1000




step: 916/1000




step: 917/1000




step: 918/1000




step: 919/1000




step: 920/1000




step: 921/1000




step: 922/1000




step: 923/1000




step: 924/1000




step: 925/1000




step: 926/1000




step: 927/1000




step: 928/1000




step: 929/1000




step: 930/1000




step: 931/1000




step: 932/1000




step: 933/1000




step: 934/1000




step: 935/1000




step: 936/1000




step: 937/1000




step: 938/1000




step: 939/1000




step: 940/1000




step: 941/1000




step: 942/1000




step: 943/1000




step: 944/1000




step: 945/1000




step: 946/1000




step: 947/1000




step: 948/1000




step: 949/1000




step: 950/1000




step: 951/1000




step: 952/1000




step: 953/1000




step: 954/1000




step: 955/1000




step: 956/1000




step: 957/1000




step: 958/1000




step: 959/1000




step: 960/1000




step: 961/1000




step: 962/1000




step: 963/1000




step: 964/1000




step: 965/1000




step: 966/1000




step: 967/1000




step: 968/1000




step: 969/1000




step: 970/1000




step: 971/1000




step: 972/1000




step: 973/1000




step: 974/1000




step: 975/1000




step: 976/1000




step: 977/1000




step: 978/1000




step: 979/1000




step: 980/1000




step: 981/1000




step: 982/1000




step: 983/1000




step: 984/1000




step: 985/1000




step: 986/1000




step: 987/1000




step: 988/1000




step: 989/1000




step: 990/1000




step: 991/1000




step: 992/1000




step: 993/1000




step: 994/1000




step: 995/1000




step: 996/1000




step: 997/1000




step: 998/1000
step: 999/1000


training: 100%|██████████| 1000/1000 [04:09<00:00,  4.00it/s]
initializing models: 100%|██████████| 1/1 [04:17<00:00, 257.30s/it]


In [8]:
# SAVE init_model_state_pool
# 定义保存文件的路径
filename = "./tmpdata/init_model_state_pool_sd1-5.pth"

# 使用pickle将数据保存到文件
with open(filename, 'wb') as file:
    pickle.dump(init_model_state_pool, file)

# 读取保存的文件
# with open(filename, 'rb') as f:
#     init_model_state_pool = pickle.load(f)

In [8]:
len(init_model_state_pool)

2

#### 优化扰动

#### 使用求解器的结果

#### 使用判别器的结果

In [20]:
# 该版本无显存溢出问题，但是需要21G显存
device = torch.device('cuda')
weight_dtype = torch.bfloat16
if args.mixed_precision == "fp32":
    weight_dtype = torch.float32
elif args.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16
class PAN_attacker():
    def __init__(self, lambda_D=0.1, lambda_S=10, omiga=0.5, alpha=1/255, k=2, radius=11, x_range=[0,255], steps=1, mode = "D", use_val = "last", no_attack = False):
        self.lambda_D = lambda_D
        self.lambda_S = lambda_S
        self.omiga = omiga
        self.alpha = alpha
        self.k = k
        self.radius = radius
        self.random_start = args.attack_pgd_random_start
        self.weight_dtype = torch.bfloat16  # 默认类型
        self.left = x_range[0]
        self.right = x_range[1]
        self.norm_type = 'l-infty'
        self.steps = steps
        self.mode = mode
        self.use_val = use_val
        self.noattack = no_attack
        if args.mixed_precision == "fp32":
            self.weight_dtype = torch.float32
        elif args.mixed_precision == "fp16":
            self.weight_dtype = torch.float16
        elif args.mixed_precision == "bf16":
            self.weight_dtype = torch.bfloat16
        
    def attack(self, f, perturbed_data, ori_image, vae, tokenizer, noise_scheduler):
        if self.noattack:
            print("defender no need to defend")
            return perturbed_data, 0


        f = [f[0].to(device, dtype=self.weight_dtype), f[1].to(device, dtype=self.weight_dtype)]
        vae.to(device, dtype=self.weight_dtype)
        perturbed_data = perturbed_data.to(device)
        # ori_image = deepcopy(perturbed_data).to(device)
        ori_image = ori_image.to(device)
        # batch_size = ori_image.size(0)
        # random start部分操作逻辑未设计
        if self.random_start:
            r=self.radius
            initial_pertubations = torch.zeros_like(ori_image).uniform_(-r, r).to(device)
            adv_image = perturbed_data+initial_pertubations
            perturbed_data = adv_image - self._clip_(adv_image, ori_image, mode="D")
        else:
            initial_pertubations = torch.zeros_like(ori_image).to(device)
        # 此轮攻击的初始扰动都是0
        pertubation_data_D = deepcopy(perturbed_data)
        pertubation_data_S = deepcopy(perturbed_data)
        best_loss_S = float('inf')
        best_loss_D = float('inf')
        best_pertubation_data_S = deepcopy(perturbed_data)
        best_pertubation_data_D = deepcopy(perturbed_data)

        for i in range(self.steps):
            # print(f'step {i} :per_s is {pertubations_S[0]}')
            # 更新扰动D
            pertubation_data_D, loss_D = self.update_pertubation_data_D(f, pertubation_data_D, ori_image, vae, tokenizer, noise_scheduler)
            if loss_D < best_loss_D:
                best_loss_D = loss_D
                # if mode == "D":
                    # print(f'pertubation_D, max val is {self.get_Linfty_norm(pertubations_D)}')
                    # print(f"find a better pertubation , max val is {self.get_Linfty_norm( pertubations_D.to('cpu') + perturbed_data.to('cpu') - ori_image.to('cpu') )}")
                best_pertubation_data_D = deepcopy(pertubation_data_D)
            # 更新扰动S
            pertubation_data_S, loss_S = self.update_pertubation_S(f, pertubation_data_S, pertubation_data_D, ori_image, vae, tokenizer, noise_scheduler)
            print(f'epoch: {i}, loss_S: {loss_S:.4f}, loss_D: {loss_D: .4f}')
            if loss_S < best_loss_S:
                best_loss_S = loss_S
                # if mode == "S":
                    # print(f"find a better pertubation , max val is {self.get_Linfty_norm(pertubations_S.to('cpu') + perturbed_data.to('cpu') - ori_image.to('cpu'))}")
                best_pertubation_data_S = deepcopy(pertubation_data_S)
        
        assert self.mode in ["S", "D"]
        assert self.use_val in ["best", "last"]

        if self.mode == "S":
            use_pertubation_data = pertubation_data_S if self.use_val == "last" else best_pertubation_data_S
            loss = loss_S if self.use_val == "last" else best_loss_S
        elif self.mode == "D":
            use_pertubation_data = pertubation_data_D if self.use_val == "last" else best_pertubation_data_D
            loss = loss_D if self.use_val == "last" else best_loss_D
        
        # print(f"find a better pertubation_{mode} , max val is {self.get_Linfty_norm(use_pertubations)}")
        # print(f"use_per is :{use_pertubations[2]}")
        return use_pertubation_data, loss

    def certi(self, models, adv_x, vae, noise_scheduler, input_ids, weight_dtype=None, target_tensor=None):
        unet, text_encoder = models
        unet.zero_grad()
        text_encoder.zero_grad()
        device = torch.device("cuda")

        adv_latens = vae.encode(adv_x.to(device, dtype=weight_dtype)).latent_dist.sample()
        adv_latens = adv_latens * vae.config.scaling_factor

        noise = torch.randn_like(adv_latens)
        bsz = adv_latens.shape[0]
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=adv_latens.device)
        timesteps = timesteps.long()

        noisy_latents = noise_scheduler.add_noise(adv_latens, noise, timesteps)
        encoder_hidden_states = text_encoder(input_ids.to(device))[0]
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(adv_latens, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        if target_tensor is not None:
            timesteps = timesteps.to(device)
            noisy_latents = noisy_latents.to(device)
            xtm1_pred = torch.cat(
                [
                    noise_scheduler.step(
                        model_pred[idx: idx + 1],
                        timesteps[idx: idx + 1],
                        noisy_latents[idx: idx + 1],
                    ).prev_sample
                    for idx in range(len(model_pred))
                ]
            )
            xtm1_target = noise_scheduler.add_noise(target_tensor, noise.to(device), (timesteps - 1).to(device))
            loss = loss - F.mse_loss(xtm1_pred, xtm1_target)
        return loss

    def get_loss_D(self, f, adv_image, ori_image, vae, tokenizer, noise_scheduler):
        input_ids = tokenizer(
            args.instance_prompt,
            truncation=True,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids.repeat(len(adv_image), 1)

        loss_P = self.certi(f, adv_image, vae, noise_scheduler, input_ids, weight_dtype=self.weight_dtype)
        # 取最大是否合适
        pertubation_linf = torch.max(self.get_Linfty_norm(adv_image-ori_image))
        loss = - loss_P + (self.lambda_D * torch.abs(pertubation_linf)**self.k)
        return loss

    def update_pertubation_data_D(self, f, adv_image, ori_image, vae, tokenizer, noise_scheduler):
        adv_image.requires_grad = True
        loss = self.get_loss_D(f, adv_image, ori_image, vae, tokenizer, noise_scheduler)
        loss.backward()
        grad_ml_alpha = self.alpha * adv_image.grad.sign()
        adv_image_new = adv_image - grad_ml_alpha
        adv_image_new = self._clip_(adv_image_new, ori_image, mode = "D")
        adv_image_new = adv_image_new.detach()
        torch.cuda.empty_cache()
        return adv_image_new, loss.item()

    def update_pertubation_S(self, f, pertubation_data_S, pertubation_data_D, ori_image, vae, tokenizer, noise_scheduler):
        # print(f'old pertubation_S: {pertubation_S[2]}')
        pertubation_data_S.requires_grad = True
        adv_image_S = pertubation_data_S
        adv_image_D = pertubation_data_D

        input_ids = tokenizer(
            args.instance_prompt,
            truncation=True,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids.repeat(len(adv_image_S), 1)

        loss_P_S = self.certi(f, adv_image_S, vae, noise_scheduler, input_ids, weight_dtype=self.weight_dtype)
        loss_P_D = self.certi(f, adv_image_D, vae, noise_scheduler, input_ids, weight_dtype=self.weight_dtype)

        pertubation_linf_S = torch.max(self.get_Linfty_norm(adv_image_S-ori_image))
        loss = - loss_P_S + self.lambda_S * (torch.abs(pertubation_linf_S)**self.k) + self.omiga * (torch.abs(loss_P_S - loss_P_D)**self.k)
        loss.backward()

        # print(f'grad:{self.alpha * pertubation_S.grad.sign()[0]}')
        # print(f'now pertubation_S: {pertubation_S[0]}')
        grad_ml_alpha = self.alpha * adv_image_S.grad.sign()
        # print(f'old pertubation_S: {pertubation_S[2]}')
        # print(f' grad_ml_alpha: {grad_ml_alpha[2]}')
        adv_image_S_new = adv_image_S - grad_ml_alpha
        # print(f'inner:{self.get_Linfty_norm(adv_image_S - grad_ml_alpha-ori_image)}')
        # 裁剪到0～255之间,并确保扰动没有超出范围
        adv_image_S_new = self._clip_(adv_image_S_new, ori_image, mode='S')
        # print(f'new pertubation_S: {pertubation_S_new[2]}')
        adv_image_S_new = adv_image_S_new.detach()
        # print(f'new pertubation_S: {pertubation_S[2]}')
        torch.cuda.empty_cache()
        return adv_image_S_new, loss.item()

    def get_Linfty_norm(self, images):
        abs_images = torch.abs(images)
        max_pixels_per_image, _ = torch.max(abs_images, dim=3)
        max_pixels_per_image, _ = torch.max(max_pixels_per_image, dim=2)
        Linfty_norm, _ = torch.max(max_pixels_per_image, dim=1)
        return Linfty_norm

    def _clip_(self, adv_x, x, mode):
        adv_x = adv_x - x
        if self.norm_type == 'l-infty':
            if mode == 'S':
                adv_x.clamp_(-self.radius, self.radius)
        else:
            raise NotImplementedError
        adv_x = adv_x + x
        adv_x.clamp_(self.left, self.right)
        return adv_x


# my_attacker = PAN_attacker(lambda_D = 0.0001, lambda_S = 0.05, alpha = 0.2, omiga = 0.5, k = 2, x_range = [0,255], radius = 11, steps=1, mode = "D", use_val = "last")
my_attacker = PAN_attacker(lambda_D = 0.0001, lambda_S = 0.05, alpha = 1, omiga = 0.5, k = 2, x_range = [0,255], radius = 6, steps=6, mode = "S", use_val = "last")

In [None]:
raise ValueError('wait')

In [21]:
perturbed_data = deepcopy(original_data)

In [22]:
args.total_trail_num = 4
args.total_train_steps = 10
args.interval = 2

In [None]:
# 提取保存的中间状态的step数据（0,199，399...）        
steps_list = list(init_model_state_pool[0].keys())
# 进度条，总train_few_step调用的次数
pbar = tqdm(total=args.total_trail_num * num_models * (args.interval // args.advance_steps) * len(steps_list), desc="meta poison with model ensemble")
cnt=0
# learning perturbation over the ensemble of models
# 在多个模型集合上进行扰动优化
# 多次实验
for _ in range(args.total_trail_num):
    # 针对每一个模型
    for model_i in range(num_models):
        # 确定关键组件
        text_encoder, unet, tokenizer, noise_scheduler, vae = MODEL_BANKS[model_i]
        # 对于每一个中间状态step
        for split_step in steps_list: 
            # 加载unet和文本编码器的中间状态参数
            unet.load_state_dict(init_model_state_pool[model_i][split_step]["unet"])
            text_encoder.load_state_dict(init_model_state_pool[model_i][split_step]["text_encoder"])
            f = [unet, text_encoder]
            # 每advance_steps步进行一次防御优化
            for j in range(args.interval // args.advance_steps):
                before = deepcopy(perturbed_data)
                perturbed_data,rubust_loss = my_attacker.attack(f, perturbed_data, original_data, vae, tokenizer, noise_scheduler)
                print(my_attacker.get_Linfty_norm(perturbed_data.to('cpu')-before.to('cpu')))
                print(my_attacker.get_Linfty_norm(perturbed_data.to('cpu')-original_data))
                # break
                # perturbed_data,rubust_loss = defender.perturb(f, perturbed_data, original_data, vae, tokenizer, noise_scheduler,)
                # 扰动优化次数更新 +1
                # wandb.log({"defender_rubust_loss_without_MAT": rubust_loss})
                cnt+=1
                
                f = train_few_step(
                    args,
                    f,
                    tokenizer,
                    noise_scheduler,
                    vae,
                    perturbed_data.float(),
                    args.advance_steps,
                    copy_flag = False,
                )
                pbar.update(1)
                # 每1000次扰动优化，保存一次扰动示例图像
                if cnt % 1000 == 0:
                    save_image(perturbed_data, f"{cnt}")
            
            # frequently release the memory due to limited GPU memory, 
            # env with more gpu might consider to remove the following lines for boosting speed
            # 释放资源
            del f 
            torch.cuda.empty_cache()
            # break
            
        del unet, text_encoder, tokenizer, noise_scheduler, vae

        if torch.cuda.is_available():
            torch.cuda.empty_cache() 
        # break
    import gc
    gc.collect()
    torch.cuda.empty_cache()   
    # break   
pbar.close()
# 保存最后的结果
save_image(perturbed_data, "final")

In [24]:
save_image(perturbed_data, "final")

In [29]:
diff = perturbed_data.to('cpu')-original_data

In [None]:
show_images(diff)

In [None]:
show_images(perturbed_data)

In [None]:
my_attacker.get_Linfty_norm(perturbed_data.to('cpu')-original_data)

In [17]:

import matplotlib.pyplot as plt

def show_images(perturbed_data):
    # 检查 perturbed_data 是否是 4D Tensor
    if len(perturbed_data.shape) != 4 or perturbed_data.shape[1] != 3:
        raise ValueError("Input tensor must have shape [4, 3, 512, 512]")

    # 转换到CPU并转换为numpy数组
    images = perturbed_data.cpu().numpy()
    
    # 创建一个4x1的图像显示框架
    fig, axs = plt.subplots(1, 4, figsize=(20, 5))
    
    for i in range(4):
        img = images[i].transpose(1, 2, 0)  # 调整维度为 H x W x C
        axs[i].imshow(img.astype('uint8'))  # 确保数据类型为uint8以显示RGB图片
        axs[i].axis('off')  # 关闭坐标轴显示
    
    plt.show()

In [None]:
(perturbed_data[2]>255).all()

In [None]:
show_images(original_data)

In [None]:
show_images(perturbed_data)

In [97]:
save_image(perturbed_data, "final")