# Hidden Characters

Hi! Welcome to the official Colab demo for our project "Diffusion Illusions: Hiding Images in Plain Sight": [https://ryanndagreat.github.io/Diffusion-Illusions/](https://ryanndagreat.github.io/Diffusion-Illusions/)

This project was inspired by our paper "Peekaboo: Text to Image Diffusion Models are Zero-Shot Segmentors". The Peekaboo project website: [https://ryanndagreat.github.io/peekaboo/](https://ryanndagreat.github.io/peekaboo/)

Instructions:

0. Go to the Runtime menu, and make sure this notebook is using GPU!
1. Run the top 2 code cells (one cleans colab's junk and downloads the source code, while the other installs python packages)
2. Click 'Runtime', then 'Restart Runtime'. You need to do this the first time you open this notebook to avoid weird random errors from the pip installations.
3. Run the code cells that load Stable Diffusion. The first time you run them, the model will take a few minutes to download; subsequent runs won't take long at all.
4. Run all the cells below that, and customize prompt_a, prompt_b, prompt_c, prompt_d, and prompt_z!
5. Take the resulting images (image_a, image_b, image_c, and image_d), print them out, stack them on top of each other, and shine a backlight through them, as shown on the Diffusion Illusions website (link above!)

I may also create a YouTube tutorial if there's interest. Let me know if this would be helpful!

This notebook was written by Ryan Burgert. Feel free to reach out to me at rburgert@cs.stonybrook.edu if you have any questions! 

In [None]:
%%bash
# First-run setup: if the working directory is not yet a git checkout, wipe it
# and clone the Diffusion-Illusions repository (master branch) into it.
# On later runs .git exists, so nothing happens.
# WARNING: this deletes everything in the current directory, hidden files included.
# NOTE(review): `.*` also matches `.` and `..`; rm presumably refuses to remove
# those and just prints an error -- confirm it is harmless on your shell.
if [ ! -d ".git" ]; then 
    rm -rf * .*; #Get rid of Colab's default junk files
    git clone -b master https://github.com/RyannDaGreat/Diffusion-Illusions .
fi

In [None]:
# Install (or upgrade) the packages listed in the repository's requirements.txt,
# plus the `rp` utility library used throughout this notebook.
%pip install --upgrade -r requirements.txt
%pip install rp --upgrade
# You may need to restart the runtime after installing these
# I'm not sure why this helps, but all sorts of weird random errors pop up in Colab if you don't

In [None]:
import numpy as np
import rp
import torch
import torch.nn as nn
import torch.nn.functional as F
import source.stable_diffusion as sd
from easydict import EasyDict
from source.learnable_textures import LearnableImageFourier
from source.stable_diffusion_labels import NegativeLabel
from itertools import chain
import time

## QRCode模式：上传目标图片

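如果你在设置prompts的cell中设置了 `USE_IMAGE_FOR_Z = True`，请在这个cell中上传你的目标图片（比如QRCode）。注意：这个cell依赖 `USE_IMAGE_FOR_Z`（在设置prompts的cell中定义）和 `device`（在加载Stable Diffusion的cell中定义），因此请先运行这两个cell，再运行本cell；否则本cell无法正确加载图片。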

**使用方法：**
1. 运行下面的代码cell
2. 点击上传按钮，选择你的图片文件
3. 图片会被自动调整大小并转换为正确的格式


In [None]:
# ===== 上传目标图片（仅在USE_IMAGE_FOR_Z=True时使用） =====
# 注意：这个cell需要在device定义之后运行（即运行Cell 9之后）
target_image_z = None

if USE_IMAGE_FOR_Z:
    try:
        from google.colab import files
        from PIL import Image
        import io
        
        print("请上传你的目标图片（例如QRCode）：")
        uploaded = files.upload()
        
        if uploaded:
            # 获取上传的文件
            file_name = list(uploaded.keys())[0]
            image_bytes = uploaded[file_name]
            
            # 加载图片
            pil_image = Image.open(io.BytesIO(image_bytes))
            
            # 转换为RGB格式（如果是RGBA或其他格式）
            if pil_image.mode != 'RGB':
                pil_image = pil_image.convert('RGB')
            
            # 转换为numpy数组并调整大小到256x256（与learnable_image的尺寸匹配）
            target_image_z = np.array(pil_image)
            target_image_z = rp.resize_image(target_image_z, (256, 256))
            
            # 转换为正确的格式：先确保是RGB格式，再转换为float，最后转换为torch tensor
            # 注意：device需要在Cell 9中定义
            if 'device' in globals():
                # 步骤1: 确保是RGB格式（HWC格式）
                target_image_z = rp.as_rgb_image(target_image_z)
                # 步骤2: 转换为float格式（值在[0,1]之间），仍然是numpy数组（HWC）
                target_image_z = rp.as_float_image(target_image_z)
                
                # QRCode预处理：转换为黑白（二值化），确保与黑白prompts匹配
                # 对于QRCode，我们需要高对比度、黑白分明，与前四张黑白图片风格一致
                if target_image_z.ndim == 3:
                    # 步骤1: 转换为灰度图（如果QRCode是彩色的，取RGB平均值）
                    gray = np.mean(target_image_z, axis=2, keepdims=True)
                    gray = np.repeat(gray, 3, axis=2)
                    
                    # 步骤2: 归一化到[0,1]范围
                    gray_min = gray.min()
                    gray_max = gray.max()
                    if gray_max > gray_min:
                        gray_normalized = (gray - gray_min) / (gray_max - gray_min)
                    else:
                        gray_normalized = gray
                    
                    # 步骤3: 强二值化（阈值化），转换为纯黑白
                    # 使用自适应阈值或固定阈值，确保QRCode清晰
                    try:
                        from skimage.filters import threshold_otsu
                        # 尝试使用Otsu算法自动计算最佳阈值
                        threshold = threshold_otsu(gray_normalized[:,:,0])
                        print(f"✓ 使用Otsu算法计算阈值: {threshold:.3f}")
                    except:
                        # 如果失败，使用固定阈值
                        threshold = 0.5
                        print("✓ 使用固定阈值: 0.5")
                    
                    # 强二值化：直接转换为0或1，不保留中间值
                    target_image_z = (gray_normalized[:,:,0] > threshold).astype(np.float32)
                    target_image_z = np.stack([target_image_z, target_image_z, target_image_z], axis=2)
                    
                    # 可选：形态学操作，清理噪声（可选，但有助于QRCode清晰度）
                    try:
                        from scipy import ndimage
                        # 轻微的开运算，去除小噪点
                        target_image_z = ndimage.binary_opening(target_image_z[:,:,0], structure=np.ones((2,2))).astype(np.float32)
                        target_image_z = np.stack([target_image_z, target_image_z, target_image_z], axis=2)
                        print("✓ 已应用形态学操作清理噪声")
                    except:
                        pass  # 如果scipy不可用，跳过
                    
                    print("✓ QRCode已转换为纯黑白高对比度格式，与前四张图片风格匹配")
                
                # 步骤3: 转换为torch tensor（CHW格式）并移动到device
                target_image_z = rp.as_torch_image(target_image_z).to(device)
                
                print(f"✓ 图片已加载，尺寸: {target_image_z.shape}")
                print("预览目标图片：")
                rp.display_image(rp.as_numpy_image(target_image_z))
            else:
                print("⚠️ 请先运行Cell 9来定义device变量")
                target_image_z = None
                USE_IMAGE_FOR_Z = False
        else:
            print("⚠️ 未上传图片，将使用文本prompt模式")
            USE_IMAGE_FOR_Z = False
    except ImportError:
        print("⚠️ 不在Colab环境中，无法上传图片。将使用文本prompt模式")
        USE_IMAGE_FOR_Z = False
else:
    print("ℹ️ QRCode模式未启用，跳过图片上传")


In [None]:
# Curated example prompts (only ones known to work well).
example_prompts = rp.load_yaml_file('source/example_prompts.yaml')
print('Available example prompts:', ', '.join(example_prompts))

# ===== QR-code mode: use an uploaded image as the target for prompt_z =====
# Set USE_IMAGE_FOR_Z = True to make the overlay of the four images match your
# own picture (e.g. a QR code), then upload that picture in the upload cell.
USE_IMAGE_FOR_Z = False  # set to True to enable image mode

if not USE_IMAGE_FOR_Z:
    # Default mode: plain strings picked from the example prompts -- replace
    # them with whatever you want.
    #   prompt_a..prompt_d: the four normal-looking images
    #   prompt_z: the hidden image that appears when all four are overlaid
    prompt_a, prompt_b, prompt_c, prompt_d, prompt_z = rp.gather(
        example_prompts,
        'miku froggo lipstick kitten_in_box darth_vader'.split(),
    )
else:
    # QR-code mode: each image keeps its own meaning (so the illusion still
    # works), but every prompt stresses a high-contrast black-and-white style
    # so the overlay can show the QR code clearly. Modeled on the paper's
    # example of four "playground" images that overlay into a QR code.
    prompt_a = "playground scene, high contrast, black and white photography, monochrome"
    prompt_b = "urban building, high contrast, black and white, stark lighting"
    prompt_c = "geometric architecture, high contrast, monochrome, simple composition"
    prompt_d = "abstract landscape, high contrast, black and white, minimalist"
    prompt_z = "QR code"  # not used as the target in this mode, but must be defined
    print(
        "⚠️ QRCode模式：已自动使用有语义内容但风格匹配QRCode的prompts",
        "   每张图片都有独立的语义解释（保持illusion效果）",
        "   但风格上强调高对比度、黑白，有助于overlay后显示QRCode",
        "   如需自定义，可以修改prompt_a, prompt_b, prompt_c, prompt_d",
        "   建议在prompts中添加：high contrast, black and white, monochrome 等关键词",
        sep='\n',
    )

negative_prompt = ''

# Summary of the configuration that will be used for training.
print()
print('Negative prompt:', repr(negative_prompt))
print()
print('Chosen prompts:')
print('    prompt_a =', repr(prompt_a))
print('    prompt_b =', repr(prompt_b))
print('    prompt_c =', repr(prompt_c))
print('    prompt_d =', repr(prompt_d))
print('    prompt_z =', repr(prompt_z))
print()

if not USE_IMAGE_FOR_Z:
    print(
        'ℹ️ 当前使用文本prompt模式。要启用QRCode模式，请设置 USE_IMAGE_FOR_Z = True',
        '   提示：QRCode模式下建议在prompts中添加"high contrast, black and white"等关键词',
        sep='\n',
    )
else:
    print(
        '⚠️ QRCode模式已启用：prompt_z将使用上传的图片作为目标',
        '⚠️ 已自动切换到有语义内容但风格匹配QRCode的prompts（a, b, c, d）',
        '   每张图片都有独立的语义解释，保持illusion效果',
        '   但风格上强调高对比度、黑白，有助于overlay后显示QRCode',
        '   如需自定义prompts，请在设置USE_IMAGE_FOR_Z之后修改prompt_a, prompt_b, prompt_c, prompt_d',
        '   建议在prompts中添加：high contrast, black and white, monochrome 等关键词',
        sep='\n',
    )

# New Section

In [None]:
# Load Stable Diffusion only once per session: `s` stays in the notebook
# namespace, so re-running this cell skips the slow model load.
if 's' not in globals():
    gpu = 'cuda:0'
    model_name = "CompVis/stable-diffusion-v1-4"
    s = sd.StableDiffusion(gpu, model_name)
device = s.device

In [None]:
# One text-conditioning label per prompt, all sharing the same negative prompt.
label_a, label_b, label_c, label_d, label_z = (
    NegativeLabel(text, negative_prompt)
    for text in (prompt_a, prompt_b, prompt_c, prompt_d, prompt_z)
)

In [None]:
# Image parametrization and initialization (this cell allocates VRAM).
#
# Learnable image size -- this has big VRAM implications!
# We use implicit neural representations for better image quality. They were
# previously used in our paper "TRITON: Neural Neural Textures make Sim2Real
# Consistent" (see tritonpaper.github.io), and that representation is based on
# Fourier Feature Networks (see bmild.github.io/fourfeat).
SIZE = 256


def learnable_image_maker():
    """Build a fresh 256x256 Fourier-feature learnable image on the model's device."""
    return LearnableImageFourier(height=256, width=256, hidden_dim=256, num_features=128).to(s.device)


# Higher-resolution alternative (much more VRAM):
# learnable_image_maker = lambda: LearnableImageFourier(height=512,width=512,num_features=256,hidden_dim=256,scale=20).to(s.device);SIZE=512

# The four printable layers, created in order a, b, c, d.
image_a, image_b, image_c, image_d = (learnable_image_maker() for _ in range(4))

In [None]:
CLEAN_MODE = True # If it's False, we augment the images by randomly simulating how good a random printer might be when making the overlays...

def simulate_overlay(a,b,c,d):
    """Simulate the hidden image seen when the four layers are overlaid.

    The result is the per-pixel product of the four images, multiplied by a
    brightness factor, clamped to [0, 99] and passed through ``tanh``.

    When ``CLEAN_MODE`` is False, a random imperfect printer is simulated for
    augmentation: a random exponent applied to every layer, a random
    brightness, and each layer blended (via ``rp.blend``, with a random amount)
    towards a random black level.

    Args:
        a, b, c, d: image tensors of the same shape (presumably the outputs of
            the learnable images, values in [0, 1] -- confirm against
            LearnableImageFourier).

    Returns:
        Tensor of the same shape as the inputs.
    """
    if CLEAN_MODE:
        exp = 1
        brightness = 3
        black = 0
    else:
        exp = rp.random_float(.5, 1)
        brightness = rp.random_float(1, 5)
        black = rp.random_float(0, .5)
        # Bug fix: this previously blended undefined `bottom`/`top` variables
        # (left over from a two-layer version), which raised UnboundLocalError
        # whenever CLEAN_MODE was False. Apply the black level to every layer.
        a = rp.blend(a, black, rp.random_float())
        b = rp.blend(b, black, rp.random_float())
        c = rp.blend(c, black, rp.random_float())
        d = rp.blend(d, black, rp.random_float())
    return (a**exp * b**exp * c**exp * d**exp * brightness).clamp(0, 99).tanh()

# Zero-argument accessors that render each image on demand. They look up
# image_a..image_d when called, not when defined.
def learnable_image_a():
    return image_a()


def learnable_image_b():
    return image_b()


def learnable_image_c():
    return image_c()


def learnable_image_d():
    return image_d()


def learnable_image_z():
    # The hidden image: the simulated overlay of all four layers.
    return simulate_overlay(image_a(), image_b(), image_c(), image_d())


# Only the four layers' parameters are optimized.
params = chain.from_iterable(
    module.parameters() for module in (image_a, image_b, image_c, image_d)
)

if not (USE_IMAGE_FOR_Z and 'target_image_z' in globals() and target_image_z is not None):
    optim = torch.optim.SGD(params, lr=1e-4)  # default learning rate
else:
    # QR-code mode: a smaller learning rate for more stable training.
    optim = torch.optim.SGD(params, lr=3e-5)
    print("⚠️ QRCode模式：已降低学习率到3e-5，以提高训练稳定性")

In [None]:
labels = [label_a, label_b, label_c, label_d, label_z]
learnable_images = [
    learnable_image_a,
    learnable_image_b,
    learnable_image_c,
    learnable_image_d,
    learnable_image_z,
]

# Relative importance of each prompt, in the same order as `labels`.
# For example [1,1,1,1,5] would prioritize the hidden prompt (prompt_z).
if not (USE_IMAGE_FOR_Z and 'target_image_z' in globals() and target_image_z is not None):
    weights = [1, 1, 1, 1, 1]  # default: all prompts equally important
else:
    # QR-code mode: strongly up-weight label_z.
    # NOTE(review): in the training loop this weight scales the image loss;
    # the QR target is optimized every iteration regardless of the weight.
    weights = [1, 1, 1, 1, 10]
    print("⚠️ QRCode模式：已增加label_z权重为10，以更好地匹配QRCode")
    print("   这将使训练时更频繁地优化QRCode目标")

# Rescale so the weights sum to len(weights), i.e. they average to 1.
weights = rp.as_numpy_array(weights)
weights = weights / weights.sum() * len(weights)

In [None]:
# Frames collected during training, used to save a timelapse at the end.
ims = list()

In [None]:
def get_display_image():
    """Render every learnable image and tile the results into one image.

    The tiles are, in order, the entries of ``learnable_images`` except the
    last, followed by the simulated overlay ``learnable_image_z``.
    """
    tiles = [rp.as_numpy_image(render()) for render in learnable_images[:-1]]
    tiles.append(rp.as_numpy_image(learnable_image_z()))
    return rp.tiled_images(tiles, length=len(learnable_images), border_thickness=0)

In [None]:
NUM_ITER = 10000

# Minimum and maximum noise timesteps for the dream loss (aka score distillation loss)
s.max_step = MAX_STEP = 990
s.min_step = MIN_STEP = 10

display_eta = rp.eta(NUM_ITER, title='Status: ')

DISPLAY_INTERVAL = 200

print('Every %i iterations we display an image in the form [image_a, image_b, image_c, image_d, image_z] where'%DISPLAY_INTERVAL)
print('    image_z = image_a * image_b * image_c * image_d')
print()
print('Interrupt the kernel at any time to return the currently displayed image')
print('You can run this cell again to resume training later on')
print()
print('Please expect this to take hours to get good images (especially on the slower Colab GPU\'s! The longer you wait the better they\'ll be')

# Bound before the loop so the interrupt handler can always report it, even if
# the interrupt arrives before the first iteration starts.
iter_num = 0

try:
    # QR-code mode is active only when the upload cell produced a target image.
    qrcode_mode = (USE_IMAGE_FOR_Z and 'target_image_z' in globals() and target_image_z is not None)

    for iter_num in range(NUM_ITER):
        display_eta(iter_num)  # Print the remaining time

        preds = []

        if qrcode_mode:
            # Step 1 (every iteration): pull the simulated overlay directly
            # towards the uploaded target with a pixel-space loss.
            pred_image = learnable_image_z()[None]  # add a batch dimension
            target_resized = F.interpolate(
                target_image_z[None],
                size=(pred_image.shape[2], pred_image.shape[3]),
                mode='bilinear',
                align_corners=False,
            )

            # Mostly L2 with an L1 helper term, both strongly up-weighted and
            # scaled by label_z's normalized weight (weights[4]).
            l2_loss = ((pred_image - target_resized) ** 2).mean()
            l1_loss = torch.abs(pred_image - target_resized).mean()
            image_loss = (l2_loss * 15.0 + l1_loss * 3.0) * weights[4]
            # NOTE(review): each train_step below re-renders its own image, so
            # retaining this graph looks unnecessary; kept as-is to be safe.
            image_loss.backward(retain_graph=True)

            # Step 2: score-distillation steps on 1 or 2 randomly chosen
            # prompts among a, b, c, d.
            other_labels = list(zip(
                [label_a, label_b, label_c, label_d],
                [learnable_image_a, learnable_image_b, learnable_image_c, learnable_image_d],
                [weights[0], weights[1], weights[2], weights[3]],
            ))
            num_other = np.random.randint(1, 3)  # 1 or 2 (upper bound is exclusive)
            selected = np.random.choice(len(other_labels), size=min(num_other, len(other_labels)), replace=False)

            for idx in selected:
                label, learnable_image, weight = other_labels[idx]
                pred = s.train_step(
                    label.embedding,
                    learnable_image()[None],
                    noise_coef=.1*weight, guidance_scale=60,
                )
                preds += list(pred)
        else:
            # Normal mode: rp.random_batch with batch_size=1 over all five
            # (label, image, weight) triples.
            label_items = list(zip(labels, learnable_images, weights))
            for idx, (label, learnable_image, weight) in enumerate(rp.random_batch(label_items, batch_size=1)):
                pred = s.train_step(
                    label.embedding,
                    learnable_image()[None],
                    noise_coef=.1*weight, guidance_scale=60,
                )
                preds += list(pred)

        with torch.no_grad():
            if iter_num and not iter_num % (DISPLAY_INTERVAL * 50):
                # Wipe the slate every 50 displays so they don't get cut off
                from IPython.display import clear_output
                clear_output()

            if not iter_num % DISPLAY_INTERVAL:
                im = get_display_image()
                ims.append(im)
                rp.display_image(im)

        optim.step()
        optim.zero_grad()
except KeyboardInterrupt:
    # The interrupted iteration may already have accumulated gradients without
    # an optim.step(); clear them so they don't leak into the first step when
    # this cell is re-run to resume training.
    optim.zero_grad()
    print()
    print('Interrupted early at iteration %i'%iter_num)
    im = get_display_image()
    ims.append(im)
    rp.display_image(im)

In [None]:
# Show each optimized layer, then the hidden image produced by overlaying them.
for _title, _render in (
    ('Image A', learnable_image_a),
    ('Image B', learnable_image_b),
    ('Image C', learnable_image_c),
    ('Image D', learnable_image_d),
    ('Image Z', learnable_image_z),
):
    print(_title)
    rp.display_image(rp.as_numpy_image(_render()))
del _title, _render  # don't leave loop variables in the notebook namespace

In [None]:
def save_run(name):
    """Save the collected timelapse frames (`ims`) as numbered PNGs.

    Frames go to ``untracked/hidden_character_runs/<name>``. If that folder
    already exists, a ``_<unix timestamp>`` suffix is appended so an earlier
    run is not written into.

    Args:
        name: base name of the run folder.
    """
    folder = "untracked/hidden_character_runs/%s" % name
    if rp.path_exists(folder):
        folder += '_%i' % time.time()
    rp.make_directory(folder)

    frame_names = ['ims_%04i.png' % index for index, _ in enumerate(ims)]
    with rp.SetCurrentDirectoryTemporarily(folder):
        rp.save_images(ims, frame_names, show_progress=True)

    print()
    print('Saved timelapse to folder:', repr(folder))


save_run('untitled')  # You can give it a good custom name if you want!