## 在SD21上测试PAN攻击

In [1]:
import random
import wandb
import argparse
import copy
import hashlib
import itertools
import logging
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys
repo_path = "/data/home/yekai/github/mypro/MetaCloak"
sys.path.append(repo_path)
from pathlib import Path
import datasets
import diffusers
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.utils.import_utils import is_xformers_available
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm
from transformers import AutoTokenizer, PretrainedConfig
from robust_facecloak.model.db_train import  DreamBoothDatasetFromTensor
from robust_facecloak.model.db_train import import_model_class_from_model_name_or_path
from robust_facecloak.generic.data_utils import PromptDataset, load_data
from robust_facecloak.generic.share_args import share_parse_args

import pickle
from copy import deepcopy

  from .autonotebook import tqdm as notebook_tqdm
2024-10-09 17:43:44.794910: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-09 17:43:44.795003: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-09 17:43:44.796413: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-09 17:43:44.805412: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
class myargs():
    """Fixed experiment configuration used in place of parsed command-line arguments.

    Every setting is exposed as an instance attribute (e.g. ``args.learning_rate``).
    NOTE(review): a later cell also reads ``args.beta_s`` and ``args.beta_p``, which are not
    defined here and must be set on the instance before that cell runs.
    """

    def __init__(self):
        settings = {
            # optimiser / outer loop
            "learning_rate": 5e-7,
            "total_trail_num": 4,  # passes over the whole model ensemble (sic: "trial")
            # DreamBooth data and prompts
            "instance_prompt": "a photo of sks person",
            "class_data_dir": f"{repo_path}/prior-data/SD21base/class-person",
            "instance_data_dir_for_adversarial": f"{repo_path}/dataset/VGGFace2-clean/0/set_B",
            "output_dir": "./tmpdata",
            "class_prompt": "a photo of a person",
            # training schedule
            "total_train_steps": 1000,
            "interval": 200,
            "advance_steps": 2,
            # presumably the L-inf perturbation radius in pixel units — confirm with the attacker
            "radius": 11,
            # preprocessing and training switches
            "resolution": 512,
            "center_crop": True,
            "with_prior_preservation": True,
            "revision": None,
            "prior_loss_weight": 1.0,
            "train_text_encoder": True,
            "enable_xformers_memory_efficient_attention": True,
            "mixed_precision": "bf16",
            "attack_pgd_random_start": False,
        }
        for name, value in settings.items():
            setattr(self, name, value)


args = myargs()

### 首先是模型加载和训练代码

In [3]:
def train_few_step(
    args,
    models,
    tokenizer,
    noise_scheduler,
    vae,
    data_tensor: torch.Tensor,
    num_steps=20,
    step_wise_save=False,
    save_step=100,
    retain_graph=False,
    task_loss_name=None,
    copy_flag=True,
):
    """Run ``num_steps`` DreamBooth fine-tuning steps on a (UNet, text encoder) pair.

    Args:
        args: config object; reads ``learning_rate``, ``instance_prompt``, ``class_data_dir``,
            ``class_prompt``, ``resolution``, ``center_crop``, ``with_prior_preservation`` and
            ``prior_loss_weight``.
        models: ``[unet, text_encoder]``.
        tokenizer: tokenizer handed to ``DreamBoothDatasetFromTensor``.
        noise_scheduler: scheduler providing ``add_noise`` / ``get_velocity`` and the
            ``num_train_timesteps`` / ``prediction_type`` config entries.
        vae: frozen VAE used to encode images to latents (moved to CUDA / bf16 in place).
        data_tensor: instance images wrapped by ``DreamBoothDatasetFromTensor``.
        num_steps: number of optimisation steps; one dataset item is used per step.
        step_wise_save: if True, also collect CPU copies of both state dicts at step 0 and at
            every step where ``(step + 1) % save_step == 0``.
        save_step: snapshot interval used together with ``step_wise_save``.
        retain_graph: forwarded to ``loss.backward``.
        task_loss_name: if not None, each step's loss is logged to wandb under this key.
        copy_flag: if True, deep copies of ``models`` are trained and the originals are left
            untouched; if False, ``models`` are trained (and moved to CUDA / bf16) in place.

    Returns:
        ``[unet, text_encoder]``, followed by a ``{step: {"unet": ..., "text_encoder": ...}}``
        dict of snapshots when ``step_wise_save`` is True.
    """
    if copy_flag:
        unet, text_encoder = copy.deepcopy(models[0]), copy.deepcopy(models[1])
    else:
        unet, text_encoder = models[0], models[1]
    print('copy model done')

    # The UNet and text encoder are optimised jointly. The chain is materialised into a list
    # because a bare itertools.chain is a one-shot iterator: the optimizer constructor would
    # exhaust it, turning the clip_grad_norm_ call below into a silent no-op.
    params_to_optimize = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
    optimizer = torch.optim.AdamW(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(0.9, 0.999),
        weight_decay=1e-2,
        eps=1e-08,
    )

    train_dataset = DreamBoothDatasetFromTensor(
        data_tensor,
        args.instance_prompt,  # e.g. "a photo of sks person"
        tokenizer,
        args.class_data_dir,
        args.class_prompt,
        args.resolution,
        args.center_crop,
    )

    weight_dtype = torch.bfloat16
    device = torch.device("cuda")

    # Move every model used in the forward pass to the GPU in bf16 (in place).
    vae.to(device, dtype=weight_dtype)
    text_encoder.to(device, dtype=weight_dtype)
    unet.to(device, dtype=weight_dtype)

    # Discard gradients left on these parameters by earlier backward passes (for example an
    # adversarial attack that backpropagated through the same models). Otherwise the first
    # optimizer step would apply them on top of this step's loss.
    optimizer.zero_grad()

    step2modelstate = {}
    pbar = tqdm(total=num_steps, desc="training")
    print("Start training...")
    for step in range(num_steps):
        print(f"step: {step}/{num_steps}")
        if step_wise_save and ((step + 1) % save_step == 0 or step == 0):
            # Snapshot on CPU, then move the models back to the GPU to keep training.
            step2modelstate[step] = {
                "unet": copy.deepcopy(unet.cpu().state_dict()),
                "text_encoder": copy.deepcopy(text_encoder.cpu().state_dict()),
            }
            unet.to(device, dtype=weight_dtype)
            text_encoder.to(device, dtype=weight_dtype)

        pbar.update(1)
        unet.train()
        text_encoder.train()

        # Cycle through the dataset, one item per step.
        step_data = train_dataset[step % len(train_dataset)]
        # A batch of two: [instance image, class image] with the matching prompt ids.
        pixel_values = torch.stack(
            [step_data["instance_images"].to(device), step_data["class_images"].to(device)]
        ).to(device, dtype=weight_dtype)
        input_ids = torch.cat(
            [step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0
        ).to(device)

        # Encode the images to latents and apply the VAE scaling factor.
        latents = vae.encode(pixel_values).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

        # Forward diffusion: add noise at an independently sampled timestep per image.
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
        ).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Text conditioning, then the UNet's prediction for the noisy latents.
        encoder_hidden_states = text_encoder(input_ids)[0]
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

        # The regression target depends on what the scheduler is configured to predict.
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        if args.with_prior_preservation:
            # Split back into the instance and class halves created by torch.stack above.
            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
            target, target_prior = torch.chunk(target, 2, dim=0)
            instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            # Prior-preservation term: keeps the model able to generate the generic class.
            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
            loss = instance_loss + args.prior_loss_weight * prior_loss
        else:
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        if task_loss_name is not None:
            wandb.log({f"{task_loss_name}": loss.item()})

        loss.backward(retain_graph=retain_graph)
        # Clip the global gradient norm to 1.0 (raises if the norm is non-finite).
        torch.nn.utils.clip_grad_norm_(params_to_optimize, 1.0, error_if_nonfinite=True)
        optimizer.step()
        optimizer.zero_grad()

    pbar.close()
    if step_wise_save:
        return [unet, text_encoder], step2modelstate
    return [unet, text_encoder]

# 主要模型的加载
def load_model(args, model_path):
    """Load the Stable Diffusion components used for DreamBooth-style fine-tuning.

    Args:
        args: config object; reads ``revision``, ``train_text_encoder`` and
            ``enable_xformers_memory_efficient_attention``.
        model_path: diffusers-format checkpoint containing ``text_encoder``, ``unet``,
            ``tokenizer``, ``scheduler`` and ``vae`` subfolders.

    Returns:
        ``(text_encoder, unet, tokenizer, noise_scheduler, vae)``. The VAE is always frozen;
        the text encoder is frozen as well unless ``args.train_text_encoder`` is set.
    """
    print(model_path)
    # Resolve which text-encoder class the checkpoint declares.
    text_encoder_cls = import_model_class_from_model_name_or_path(model_path, args.revision)

    text_encoder = text_encoder_cls.from_pretrained(
        model_path,
        subfolder="text_encoder",
        revision=args.revision,
    )
    unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet", revision=args.revision)
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        subfolder="tokenizer",
        revision=args.revision,
        use_fast=False,
    )
    # Same DDPM scheduler the checkpoint was trained with.
    noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
    vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", revision=args.revision)

    # The VAE is never updated.
    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    if args.enable_xformers_memory_efficient_attention:
        print("You selected to use efficient xformers")
        print("Make sure to install the following packages before continuing")
        print("pip install triton==2.0.0.dev20221031")
        print("pip install xformers==0.0.17.dev461")
        unet.enable_xformers_memory_efficient_attention()
    return text_encoder, unet, tokenizer, noise_scheduler, vae

def save_image(perturbed_data, id_stamp):
    """Save each image of ``perturbed_data`` as ``{args.output_dir}/noise-ckpt/{id_stamp}/noisy_<name>``.

    File names are taken, in ``Path.iterdir()`` order, from
    ``args.instance_data_dir_for_adversarial``. Images and names are paired with ``zip``, so
    the shorter of the two sequences decides how many files are written.

    Args:
        perturbed_data: image batch indexed as (N, C, H, W) with pixel values on a 0-255 scale
            (assumed from the permute/uint8 conversion below — confirm with load_data).
        id_stamp: sub-folder name, e.g. an iteration counter or ``"final"``.
    """
    save_folder = f"{args.output_dir}/noise-ckpt/{id_stamp}"
    os.makedirs(save_folder, exist_ok=True)
    noised_imgs = perturbed_data.detach()
    img_names = [
        instance_path.name
        for instance_path in Path(args.instance_data_dir_for_adversarial).iterdir()
    ]
    for img_pixel, img_name in zip(noised_imgs, img_names):
        save_path = os.path.join(save_folder, f"noisy_{img_name}")
        # Clamp before the uint8 cast: out-of-range floats would otherwise wrap around
        # (negative values are not even well-defined when cast to an unsigned type).
        pixels = (
            img_pixel.float().detach().cpu().clamp(0, 255)
            .permute(1, 2, 0).numpy().squeeze().astype(np.uint8)
        )
        Image.fromarray(pixels).save(save_path)

#### 加载模型

In [4]:
# Checkpoints making up the model ensemble (one entry per model).
# NOTE(review): the notebook title and the cached state-pool file name refer to SD2.1, but this
# path is SD1.5 — confirm which model is intended.
model_paths = [f"{repo_path}/SD/stable-diffusion-v1-5"]
num_models = len(model_paths)

In [5]:
# One (text_encoder, unet, tokenizer, noise_scheduler, vae) tuple per checkpoint path.
MODEL_BANKS = [load_model(args, path) for path in model_paths]
# Pristine copies of the pretrained weights. `state_dict()` on its own returns tensors that share
# storage with the live parameters, so in-place updates (training, `load_state_dict`) would
# silently change these entries and they could no longer be used to reset the models.
# Cloning costs one extra copy of each model's weights in host memory.
MODEL_STATEDICTS = [
    {
        "text_encoder": {k: v.detach().clone() for k, v in bank[0].state_dict().items()},
        "unet": {k: v.detach().clone() for k, v in bank[1].state_dict().items()},
    }
    for bank in MODEL_BANKS
]

You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.


/data/home/yekai/github/mypro/MetaCloak/SD/stable-diffusion-v1-5
model_path:/data/home/yekai/github/mypro/MetaCloak/SD/stable-diffusion-v1-5
You selected to used efficient xformers
Make sure to install the following packages before continue
pip install triton==2.0.0.dev20221031
pip install pip install xformers==0.0.17.dev461


#### 开始初始的预训练

In [6]:
# 加载原始扰动数据
# Working tensor: starts as the clean images and is replaced by its perturbed version later.
perturbed_data = load_data(args.instance_data_dir_for_adversarial)
# Untouched copy, passed to the attacker as the reference (clean) images.
original_data = deepcopy(perturbed_data)

# Intermediate checkpoints per model; filled by pre-training or loaded from disk.
init_model_state_pool = {}

In [7]:
# pbar = tqdm(total=num_models, desc="initializing models")
# # split sub-models
# # 对于每一个模型，都进行一次训练
# for j in range(num_models):
#     init_model_state_pool[j] = {}
#     # 提取关键模块
#     text_encoder, unet, tokenizer, noise_scheduler, vae = MODEL_BANKS[j]
    
#     # 加载unet和text_encoder的模型参数
#     unet.load_state_dict(MODEL_STATEDICTS[j]["unet"])
#     text_encoder.load_state_dict(MODEL_STATEDICTS[j]["text_encoder"])
#     # 打包unet和text_encoder
#     f_ori = [unet, text_encoder]
#     # 得到训练total_train_steps步之后的unet, text_encoder参数以及中间状态参数
#     print("start training model", j)
#     f_ori, step2state_dict = train_few_step(
#             args,
#             f_ori,
#             tokenizer,
#             noise_scheduler,
#             vae,
#             perturbed_data.float(),
#             args.total_train_steps,
#             step_wise_save=True,
#             save_step=args.interval,
#             task_loss_name=None,
#     )  
#     # init_model_state_pool就来保存训练中间状态参数
#     init_model_state_pool[j] = step2state_dict

#     # 释放占用的资源
#     del f_ori, unet, text_encoder, tokenizer, noise_scheduler, vae
#     import gc
#     gc.collect()
#     torch.cuda.empty_cache()
#     pbar.update(1)
# pbar.close()

In [7]:
# SAVE init_model_state_pool
# 定义保存文件的路径
# Pickled checkpoint pool, presumably written by the commented-out pre-training cell above:
# {model_index: {step: {"unet": state_dict, "text_encoder": state_dict}}}.
# NOTE(review): the name says SD2.1 while model_paths points at SD1.5 — confirm they match.
filename = "./tmpdata/init_model_state_pool_sd2-1.pth"

# To (re)create the file after running the pre-training cell:
#     with open(filename, 'wb') as file:
#         pickle.dump(init_model_state_pool, file)

# pickle can execute arbitrary code while loading; only load files you produced yourself.
with Path(filename).open('rb') as f:
    init_model_state_pool = pickle.load(f)

In [None]:
# Number of models in the loaded checkpoint pool (top-level keys are model indices).
len(init_model_state_pool)

#### 优化扰动

#### 使用求解器的结果

#### 使用判别器的结果

In [8]:
# 该版本无显存溢出问题，但是需要21G显存
# This version avoids GPU out-of-memory errors but needs about 21 GB of GPU memory.
device = torch.device('cuda')
# Map the configured precision to a torch dtype; unrecognised values fall back to bf16.
weight_dtype = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}.get(args.mixed_precision, torch.bfloat16)
# class PAN_attacker():
#     def __init__(self, lambda_D=0.1, lambda_S=10, omiga=0.5, alpha=1/255, k=2, radius=11, x_range=[0,255], steps=1, mode = "D", use_val = "last", no_attack = False):
#         self.lambda_D = lambda_D
#         self.lambda_S = lambda_S
#         self.omiga = omiga
#         self.alpha = alpha
#         self.k = k
#         self.radius = radius
#         self.random_start = args.attack_pgd_random_start
#         self.weight_dtype = torch.bfloat16  # 默认类型
#         self.left = x_range[0]
#         self.right = x_range[1]
#         self.norm_type = 'l-infty'
#         self.steps = steps
#         self.mode = mode
#         self.use_val = use_val
#         self.noattack = no_attack
#         if args.mixed_precision == "fp32":
#             self.weight_dtype = torch.float32
#         elif args.mixed_precision == "fp16":
#             self.weight_dtype = torch.float16
#         elif args.mixed_precision == "bf16":
#             self.weight_dtype = torch.bfloat16
        
#     def attack(self, f, perturbed_data, ori_image, vae, tokenizer, noise_scheduler):
#         if self.noattack:
#             print("defender no need to defend")
#             return perturbed_data, 0


#         f = [f[0].to(device, dtype=self.weight_dtype), f[1].to(device, dtype=self.weight_dtype)]
#         vae.to(device, dtype=self.weight_dtype)
#         perturbed_data = perturbed_data.to(device)
#         # ori_image = deepcopy(perturbed_data).to(device)
#         ori_image = ori_image.to(device)
#         # batch_size = ori_image.size(0)
#         # random start部分操作逻辑未设计
#         if self.random_start:
#             r=self.radius
#             initial_pertubations = torch.zeros_like(ori_image).uniform_(-r, r).to(device)
#             adv_image = perturbed_data+initial_pertubations
#             perturbed_data = adv_image - self._clip_(adv_image, ori_image, mode="D")
#         else:
#             initial_pertubations = torch.zeros_like(ori_image).to(device)
#         # 此轮攻击的初始扰动都是0
#         pertubation_data_D = deepcopy(perturbed_data)
#         pertubation_data_S = deepcopy(perturbed_data)
#         best_loss_S = float('inf')
#         best_loss_D = float('inf')
#         best_pertubation_data_S = deepcopy(perturbed_data)
#         best_pertubation_data_D = deepcopy(perturbed_data)

#         for i in range(self.steps):
#             # print(f'step {i} :per_s is {pertubations_S[0]}')
#             # 更新扰动D
#             pertubation_data_D, loss_D = self.update_pertubation_data_D(f, pertubation_data_D, ori_image, vae, tokenizer, noise_scheduler)
#             if loss_D < best_loss_D:
#                 best_loss_D = loss_D
#                 # if mode == "D":
#                     # print(f'pertubation_D, max val is {self.get_Linfty_norm(pertubations_D)}')
#                     # print(f"find a better pertubation , max val is {self.get_Linfty_norm( pertubations_D.to('cpu') + perturbed_data.to('cpu') - ori_image.to('cpu') )}")
#                 best_pertubation_data_D = deepcopy(pertubation_data_D)
#             # 更新扰动S
#             pertubation_data_S, loss_S = self.update_pertubation_S(f, pertubation_data_S, pertubation_data_D, ori_image, vae, tokenizer, noise_scheduler)
#             print(f'epoch: {i}, loss_S: {loss_S:.4f}, loss_D: {loss_D: .4f}')
#             if loss_S < best_loss_S:
#                 best_loss_S = loss_S
#                 # if mode == "S":
#                     # print(f"find a better pertubation , max val is {self.get_Linfty_norm(pertubations_S.to('cpu') + perturbed_data.to('cpu') - ori_image.to('cpu'))}")
#                 best_pertubation_data_S = deepcopy(pertubation_data_S)
        
#         assert self.mode in ["S", "D"]
#         assert self.use_val in ["best", "last"]

#         if self.mode == "S":
#             use_pertubation_data = pertubation_data_S if self.use_val == "last" else best_pertubation_data_S
#             loss = loss_S if self.use_val == "last" else best_loss_S
#         elif self.mode == "D":
#             use_pertubation_data = pertubation_data_D if self.use_val == "last" else best_pertubation_data_D
#             loss = loss_D if self.use_val == "last" else best_loss_D
        
#         # print(f"find a better pertubation_{mode} , max val is {self.get_Linfty_norm(use_pertubations)}")
#         # print(f"use_per is :{use_pertubations[2]}")
#         return use_pertubation_data, loss

#     def certi(self, models, adv_x, vae, noise_scheduler, input_ids, weight_dtype=None, target_tensor=None):
#         unet, text_encoder = models
#         unet.zero_grad()
#         text_encoder.zero_grad()
#         device = torch.device("cuda")

#         adv_latens = vae.encode(adv_x.to(device, dtype=weight_dtype)).latent_dist.sample()
#         adv_latens = adv_latens * vae.config.scaling_factor

#         noise = torch.randn_like(adv_latens)
#         bsz = adv_latens.shape[0]
#         timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=adv_latens.device)
#         timesteps = timesteps.long()

#         noisy_latents = noise_scheduler.add_noise(adv_latens, noise, timesteps)
#         encoder_hidden_states = text_encoder(input_ids.to(device))[0]
#         model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

#         if noise_scheduler.config.prediction_type == "epsilon":
#             target = noise
#         elif noise_scheduler.config.prediction_type == "v_prediction":
#             target = noise_scheduler.get_velocity(adv_latens, noise, timesteps)
#         else:
#             raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

#         loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

#         if target_tensor is not None:
#             timesteps = timesteps.to(device)
#             noisy_latents = noisy_latents.to(device)
#             xtm1_pred = torch.cat(
#                 [
#                     noise_scheduler.step(
#                         model_pred[idx: idx + 1],
#                         timesteps[idx: idx + 1],
#                         noisy_latents[idx: idx + 1],
#                     ).prev_sample
#                     for idx in range(len(model_pred))
#                 ]
#             )
#             xtm1_target = noise_scheduler.add_noise(target_tensor, noise.to(device), (timesteps - 1).to(device))
#             loss = loss - F.mse_loss(xtm1_pred, xtm1_target)
#         return loss

#     def get_loss_D(self, f, adv_image, ori_image, vae, tokenizer, noise_scheduler):
#         input_ids = tokenizer(
#             args.instance_prompt,
#             truncation=True,
#             padding="max_length",
#             max_length=tokenizer.model_max_length,
#             return_tensors="pt",
#         ).input_ids.repeat(len(adv_image), 1)

#         loss_P = self.certi(f, adv_image, vae, noise_scheduler, input_ids, weight_dtype=self.weight_dtype)
#         # 取最大是否合适
#         pertubation_linf = torch.max(self.get_Linfty_norm(adv_image-ori_image))
#         loss = - loss_P + (self.lambda_D * torch.abs(pertubation_linf)**self.k)
#         return loss

#     def update_pertubation_data_D(self, f, adv_image, ori_image, vae, tokenizer, noise_scheduler):
#         adv_image.requires_grad = True
#         loss = self.get_loss_D(f, adv_image, ori_image, vae, tokenizer, noise_scheduler)
#         loss.backward()
#         grad_ml_alpha = self.alpha * adv_image.grad.sign()
#         adv_image_new = adv_image - grad_ml_alpha
#         adv_image_new = self._clip_(adv_image_new, ori_image, mode = "D")
#         adv_image_new = adv_image_new.detach()
#         torch.cuda.empty_cache()
#         return adv_image_new, loss.item()

#     def update_pertubation_S(self, f, pertubation_data_S, pertubation_data_D, ori_image, vae, tokenizer, noise_scheduler):
#         # print(f'old pertubation_S: {pertubation_S[2]}')
#         pertubation_data_S.requires_grad = True
#         adv_image_S = pertubation_data_S
#         adv_image_D = pertubation_data_D

#         input_ids = tokenizer(
#             args.instance_prompt,
#             truncation=True,
#             padding="max_length",
#             max_length=tokenizer.model_max_length,
#             return_tensors="pt",
#         ).input_ids.repeat(len(adv_image_S), 1)

#         loss_P_S = self.certi(f, adv_image_S, vae, noise_scheduler, input_ids, weight_dtype=self.weight_dtype)
#         loss_P_D = self.certi(f, adv_image_D, vae, noise_scheduler, input_ids, weight_dtype=self.weight_dtype)

#         pertubation_linf_S = torch.max(self.get_Linfty_norm(adv_image_S-ori_image))
#         loss = - loss_P_S + self.lambda_S * (torch.abs(pertubation_linf_S)**self.k) + self.omiga * (torch.abs(loss_P_S - loss_P_D)**self.k)
#         loss.backward()

#         # print(f'grad:{self.alpha * pertubation_S.grad.sign()[0]}')
#         # print(f'now pertubation_S: {pertubation_S[0]}')
#         grad_ml_alpha = self.alpha * adv_image_S.grad.sign()
#         # print(f'old pertubation_S: {pertubation_S[2]}')
#         # print(f' grad_ml_alpha: {grad_ml_alpha[2]}')
#         adv_image_S_new = adv_image_S - grad_ml_alpha
#         # print(f'inner:{self.get_Linfty_norm(adv_image_S - grad_ml_alpha-ori_image)}')
#         # 裁剪到0～255之间,并确保扰动没有超出范围
#         adv_image_S_new = self._clip_(adv_image_S_new, ori_image, mode='S')
#         # print(f'new pertubation_S: {pertubation_S_new[2]}')
#         adv_image_S_new = adv_image_S_new.detach()
#         # print(f'new pertubation_S: {pertubation_S[2]}')
#         torch.cuda.empty_cache()
#         return adv_image_S_new, loss.item()

#     def get_Linfty_norm(self, images):
#         abs_images = torch.abs(images)
#         max_pixels_per_image, _ = torch.max(abs_images, dim=3)
#         max_pixels_per_image, _ = torch.max(max_pixels_per_image, dim=2)
#         Linfty_norm, _ = torch.max(max_pixels_per_image, dim=1)
#         return Linfty_norm

#     def _clip_(self, adv_x, x, mode):
#         adv_x = adv_x - x
#         if self.norm_type == 'l-infty':
#             if mode == 'S':
#                 adv_x.clamp_(-self.radius, self.radius)
#         else:
#             raise NotImplementedError
#         adv_x = adv_x + x
#         adv_x.clamp_(self.left, self.right)
#         return adv_x

from pan_worker import PANAttacker

# Earlier configurations of the (now commented-out) in-notebook attacker, kept for reference:
# my_attacker = PAN_attacker(lambda_D = 0.0001, lambda_S = 0.05, alpha = 0.2, omiga = 0.5, k = 2, x_range = [0,255], radius = 11, steps=1, mode = "D", use_val = "last")
# my_attacker = PAN_attacker(lambda_D = 0.01, lambda_S = 10, alpha = 1, omiga = 0.5, k = 2, x_range = [0,255], radius = 11, steps=6, mode = "S", use_val = "last")
my_attacker = PANAttacker(
    lambda_D=0.01,
    lambda_S=10,
    step_size=1,
    omiga=0.5,
    k=2,
    x_range=[0, 255],
    radius=11,
    steps=6,
    mode="S",
    use_val="last",
    args=args,
)


In [None]:
# Deliberate stop, presumably so that "Run All" halts here and the long optimisation cells
# below are only run by hand.
raise ValueError('wait')

In [18]:
# Restart from a fresh copy of the clean images before (re)running the attack.
perturbed_data = deepcopy(original_data)

In [9]:
# Shrunken settings for a quick run, overriding the myargs defaults (4 / 1000 / 200).
# With advance_steps = 2, interval // advance_steps gives one attack/train round per checkpoint.
args.total_trail_num = 2
args.total_train_steps = 10
args.interval = 2

In [None]:
# 提取保存的中间状态的step数据（0,199，399...）        
steps_list = list(init_model_state_pool[0].keys())
# 进度条，总train_few_step调用的次数
pbar = tqdm(total=args.total_trail_num * num_models * (args.interval // args.advance_steps) * len(steps_list), desc="meta poison with model ensemble")
cnt=0
# learning perturbation over the ensemble of models
# 在多个模型集合上进行扰动优化
# 多次实验
for _ in range(args.total_trail_num):
    # 针对每一个模型
    for model_i in range(num_models):
        # 确定关键组件
        text_encoder, unet, tokenizer, noise_scheduler, vae = MODEL_BANKS[model_i]
        # 对于每一个中间状态step
        for split_step in steps_list: 
            # 加载unet和文本编码器的中间状态参数
            unet.load_state_dict(init_model_state_pool[model_i][split_step]["unet"])
            text_encoder.load_state_dict(init_model_state_pool[model_i][split_step]["text_encoder"])
            f = [unet, text_encoder]
            # 每advance_steps步进行一次防御优化
            for j in range(args.interval // args.advance_steps):
                before = deepcopy(perturbed_data)
                perturbed_data,rubust_loss = my_attacker.attack(f, perturbed_data, original_data, vae, tokenizer, noise_scheduler)
                print(my_attacker.get_Linfty_norm(perturbed_data.to('cpu')-before.to('cpu')))
                print(my_attacker.get_Linfty_norm(perturbed_data.to('cpu')-original_data))
                # break
                # perturbed_data,rubust_loss = defender.perturb(f, perturbed_data, original_data, vae, tokenizer, noise_scheduler,)
                # 扰动优化次数更新 +1
                # wandb.log({"defender_rubust_loss_without_MAT": rubust_loss})
                cnt+=1
                
                f = train_few_step(
                    args,
                    f,
                    tokenizer,
                    noise_scheduler,
                    vae,
                    perturbed_data.float(),
                    args.advance_steps,
                    copy_flag = False,
                )
                pbar.update(1)
                # 每1000次扰动优化，保存一次扰动示例图像
                if cnt % 1000 == 0:
                    save_image(perturbed_data, f"{cnt}")
            
            # frequently release the memory due to limited GPU memory, 
            # env with more gpu might consider to remove the following lines for boosting speed
            # 释放资源
            del f 
            torch.cuda.empty_cache()
            # break
            
        del unet, text_encoder, tokenizer, noise_scheduler, vae

        if torch.cuda.is_available():
            torch.cuda.empty_cache() 
        # break
    import gc
    gc.collect()
    torch.cuda.empty_cache()   
    # break   
pbar.close()
# 保存最后的结果
save_image(perturbed_data, "final")

In [11]:
# 提取保存的中间状态的step数据（0,199，399...）        
steps_list = list(init_model_state_pool[0].keys())
# 进度条，总train_few_step调用的次数
pbar = tqdm(total=args.total_trail_num * num_models * (args.interval // args.advance_steps) * len(steps_list), desc="meta poison with model ensemble")
cnt=0
for _ in range(args.total_trail_num):          
            # 针对每一个模型
            for model_i in range(num_models):
                print(f'using model {model_i}')
                # 确定关键组件
                # start_time = time.time()
                text_encoder, unet, tokenizer, noise_scheduler, vae = MODEL_BANKS[model_i]
                # 对于每一个中间状态step
                for split_step in steps_list: 
                    # 加载unet和文本编码器的中间状态参数
                    unet.load_state_dict(init_model_state_pool[model_i][split_step]["unet"])
                    text_encoder.load_state_dict(init_model_state_pool[model_i][split_step]["text_encoder"])
                    f = [unet, text_encoder]
                    # f = [unet.to(device_1), text_encoder.to(device_1)]
                    
                    # 每advance_steps步进行一次防御优化/对于每一组模型参数，进行200/2=100次对抗训练
                    print(f'start {args.interval // args.advance_steps} times of defense optimization in step-{split_step} model')
                    for j in range(args.interval // args.advance_steps):
                        # 更新一次扰动，使得扰动更加强大,后续需要在此处引入随机性（多轮采样优化），并以扰动的平均值作为后续的扰动
                        # vkeilo add it
                        mean_delta = perturbed_data.clone().detach()
                        for k in range(1):
                            # print(f'sample delta {k}/{args.sampling_times_delta} times')
                            perturbed_data,rubust_loss = my_attacker.attack(f, perturbed_data, original_data, vae, tokenizer, noise_scheduler,)
                            wandb.log({"perturbedloss": rubust_loss})
                            # 此处引入随机梯度朗之万动力学
                            mean_delta = args.beta_s * mean_delta + (1 - args.beta_s) * perturbed_data
                        mean_delta.detach()
                        perturbed_data = mean_delta
                        # print(f"max pixel change:{find_max_pixel_change(perturbed_data, original_data)}")
                        # f[0] = f[0].to(device_0)
                        # f[1] = f[1].to(device_0)
                        # perturbed_data = defender.perturb(f, perturbed_data, original_data, vae, tokenizer, noise_scheduler)
                        
                        # 扰动优化次数更新 +1
                        cnt+=1
                        # 在新的扰动数据下，训练advance_steps步，后续需要在此处引入随机性（多轮采样优化参数），并以参数的平均值作为模型的参数
                        back_parameters_list = [f[0].state_dict(),
                                                f[1].state_dict()]

                        mean_theta_list = [f[0].state_dict(),
                                        f[1].state_dict()]
                        
                        # print(f'start {args.sampling_times_theta} times of theta sampling')
                        for k in range(1):
                            # print(f'sample theta {k}/{args.sampling_times_theta} times')
                            f = train_few_step(
                                args,
                                f,
                                tokenizer,
                                noise_scheduler,
                                vae,
                                perturbed_data.float(),
                                args.advance_steps,
                                # device = device_1
                                dpcopy = False,
                                task_loss_name='model_theta_loss',
                            )
                            torch.cuda.empty_cache()
                            for model_index, model in enumerate(f):
                                # print(f"\nbefore culcu, GPU: {gpu.name}, Free Memory: {gpu.memoryFree / 1024:.2f} GB")
                                for name, p in model.named_parameters():
                                    # 先尝试固定学习率的（因为迭代次数暂未确定）
                                    # lr_now = lr_scheduler.get_last_lr()[0]
                                    # 参数采样,引入随机性
                                    # 模型参数也使用指数平均
                                    # mean_theta_list[model_index][name] = args.beta_s * mean_theta_list[model_index][name] + (1 - args.beta_s) * p.data.to('cpu')
                                    # mean_theta_list[model_index][name] = args.beta_s * mean_theta_list[model_index][name] + (1 - args.beta_s) * p.data
                                    mean_theta_list[model_index][name].mul_(args.beta_s).add_((1 - args.beta_s) * p.data)
                                # print(f"\nafter calcu params, GPU: {gpu.name}, Free Memory: {gpu.memoryFree / 1024:.2f} GB")
                                # torch.cuda.empty_cache()
                        # lr_scheduler.step()
                        # 对于模型的unet和文本编码器，分别更新参数
                        for back_parameters, mean_theta in zip(back_parameters_list,mean_theta_list):
                            for name in back_parameters:
                                back_parameters[name] = args.beta_p * back_parameters[name] + (1 - args.beta_p) * mean_theta[name]
                                # back_parameters[name] = back_parameters[name].float()
                                # back_parameters[name].mul_(args.beta_p).add_((1 - args.beta_p) * mean_theta[name])
                        for index, model in enumerate(f):
                            # model.load_state_dict({k: v.to(device_g) for k, v in back_parameters_list[index].items()})
                            model.load_state_dict(back_parameters_list[index])
                            pass
                        del back_parameters_list
                        del mean_theta_list
                        gc.collect()
                        torch.cuda.empty_cache()
                        # f = train_few_step(
                        #     args,
                        #     f,
                        #     tokenizer,
                        #     noise_scheduler,
                        #     vae,
                        #     perturbed_data.float(),
                        #     args.advance_steps,
                        # )
                        pbar.update(1)
                        # 每1000次扰动优化，保存一次扰动示例图像
                        if cnt % 1000 == 0:
                            save_image(perturbed_data, f"{cnt}")
                    # frequently release the memory due to limited GPU memory, 
                    # env with more gpu might consider to remove the following lines for boosting speed
                    # 释放资源
                    del f 
                    torch.cuda.empty_cache()
                # end_time = time.time()
                # logger.info(f"model {model_i} adversarial training Time cost: {(end_time - start_time) / 60} min")
                del unet, text_encoder, tokenizer, noise_scheduler, vae

                if torch.cuda.is_available():
                    torch.cuda.empty_cache() 

            import gc
            gc.collect()
            torch.cuda.empty_cache()

meta poison with model ensemble:   0%|          | 0/12 [00:00<?, ?it/s]

using model 0
start 1 times of defense optimization in step-0 model
 adv_image_S.grad.sign(): tensor([[[[ 1.,  1.,  1.,  ...,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  ...,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  ...,  1.,  1.,  1.],
          ...,
          [-1., -1.,  1.,  ...,  1.,  1.,  1.],
          [-1., -1.,  1.,  ...,  1.,  1.,  1.],
          [-1., -1., -1.,  ...,  1.,  1.,  1.]],

         [[ 1.,  1.,  1.,  ...,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  ...,  1.,  1.,  1.],
          [-1., -1., -1.,  ...,  1.,  1.,  1.],
          ...,
          [-1., -1., -1.,  ...,  1.,  1.,  1.],
          [-1., -1., -1.,  ...,  1.,  1.,  1.],
          [-1., -1., -1.,  ...,  1.,  1.,  1.]],

         [[ 1.,  1.,  1.,  ...,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  ...,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  ...,  1.,  1.,  1.],
          ...,
          [-1., -1., -1.,  ...,  1.,  1.,  1.],
          [-1., -1., -1.,  ...,  1.,  1.,  1.],
          [-1., -1., -1.,  ...,  1.,  1.,

KeyboardInterrupt: 

In [24]:
# Save the current perturbed batch with the tag "final" (same save_image helper
# the optimization loop uses for its periodic snapshots).
save_image(perturbed_data, "final")

In [29]:
# Perturbation actually applied: perturbed batch (moved to CPU) minus the clean batch.
diff = perturbed_data.cpu() - original_data

In [None]:
# diff is signed (perturbed minus original). Min-max rescale it into [0, 255]
# before display so negative entries are not mangled by the uint8 conversion
# that show_images performs. clamp_min avoids division by zero when diff is constant.
diff_f = diff.float()
show_images((diff_f - diff_f.min()) / (diff_f.max() - diff_f.min()).clamp_min(1e-12) * 255)

In [None]:
# Visualize the perturbed batch.
show_images(perturbed_data)

In [None]:
# L-infinity magnitude of the perturbation (perturbed batch on CPU minus the clean batch).
my_attacker.get_Linfty_norm(perturbed_data.cpu() - original_data)

In [17]:

import matplotlib.pyplot as plt

def show_images(perturbed_data, max_images=4):
    """Display up to ``max_images`` images from a batch side by side.

    Args:
        perturbed_data: tensor batch of shape ``[N, 3, H, W]``. It is detached,
            moved to CPU and converted with ``.numpy()``. Values are treated as
            pixel intensities and clipped to ``[0, 255]`` before display.
        max_images: maximum number of leading images to show (default 4,
            matching the previous fixed 1x4 layout).

    Raises:
        ValueError: if the input is not 4-D with 3 channels in dim 1, or if
            there is no image to show.
    """
    # Validate layout: a 4-D batch with 3 (RGB) channels in dim 1.
    if len(perturbed_data.shape) != 4 or perturbed_data.shape[1] != 3:
        raise ValueError(
            f"Input tensor must have shape [N, 3, H, W], got {tuple(perturbed_data.shape)}"
        )
    # Show at most max_images, but never index past the end of a smaller batch.
    num_images = min(perturbed_data.shape[0], max_images)
    if num_images < 1:
        raise ValueError("Input tensor must contain at least one image")

    # detach() so tensors that still require grad can be converted to numpy.
    images = perturbed_data[:num_images].detach().cpu().numpy()
    # Clip before the uint8 cast: out-of-range values (> 255, or negative
    # differences) would otherwise wrap around instead of saturating.
    images = images.clip(0, 255).astype("uint8")

    # squeeze=False keeps axs 2-D even when only one image is shown.
    fig, axs = plt.subplots(1, num_images, figsize=(5 * num_images, 5), squeeze=False)
    for ax, img in zip(axs[0], images):
        ax.imshow(img.transpose(1, 2, 0))  # C x H x W -> H x W x C
        ax.axis("off")  # hide axes

    plt.show()

In [None]:
# Sanity check: does image 2 contain ANY pixel outside the displayable [0, 255]
# range? (.all() would only be True if every pixel exceeded 255.)
((perturbed_data[2] < 0) | (perturbed_data[2] > 255)).any()

In [None]:
# Visualize the clean (unperturbed) batch for comparison with the perturbed one.
show_images(original_data)

In [None]:
# Visualize the perturbed batch again, next to the clean batch shown above.
show_images(perturbed_data)

In [97]:
# Re-save the perturbed batch under the "final" tag; this overwrites whatever the
# earlier "final" save wrote if save_image maps the tag to a fixed path.
save_image(perturbed_data, "final")