In [None]:
import torch
import argparse
import sys
import os
# Model variant switch read by create_infinity_args:
# True -> Infinity-2B checkpoint/VAE, False -> Infinity-8B checkpoint/VAE.
infinity2b = False

# Make the repository root and the bundled Infinity package importable.
# Both entries are resolved against the kernel's current working directory.
sys.path.extend(
    os.path.join(os.getcwd(), relative_dir)
    for relative_dir in ('../', '../Infinity')
)

# Root directory from which create_infinity_args derives the checkpoint,
# VAE and T5 text-encoder paths.
weights_path = '/path/infinity'

# NOTE(review): these absolute paths are NOT read by create_infinity_args
# (it builds its own paths from weights_path), and their root
# '/path/weights/infinity' differs from weights_path — confirm which is intended.
vae_path = '/path/weights/infinity/infinity_vae_d32reg.pth'
model_path = '/path/weights/infinity/infinity_2b_reg.pth'
text_encoder_ckpt = '/path/weights/infinity/models--google--flan-t5-xl/snapshots/7d6315df2c2fb742f0f5b556879d730926ca9001' 

def create_infinity_args(use_infinity2b=None, weights_dir=None):
    """Build the argparse.Namespace that configures Infinity generation and watermarking.

    Args:
        use_infinity2b: When truthy, select the Infinity-2B checkpoint and VAE;
            otherwise select the Infinity-8B (sharded) checkpoint and patchify VAE.
            ``None`` (default) falls back to the notebook-level ``infinity2b`` flag.
        weights_dir: Root directory used to build the checkpoint, VAE and T5 paths.
            ``None`` (default) falls back to the notebook-level ``weights_path``.

    Returns:
        argparse.Namespace holding architecture, sampling, output and watermark
        settings.
    """
    # Resolve the fallbacks at call time rather than as default values, so that
    # edits to the notebook globals are still honored on every call, exactly as
    # the parameterless version behaved.
    if use_infinity2b is None:
        use_infinity2b = infinity2b
    if weights_dir is None:
        weights_dir = weights_path

    args = argparse.Namespace()

    # Architecture settings
    args.architecture = "infinity"
    args.pn = "1M"  # 1M = 1024x1024, 0.25M = 512x512, 0.06M = 256x256
    if use_infinity2b:
        args.vae_path = f'{weights_dir}/infinity_vae_d32reg.pth'
        args.model_path = f'{weights_dir}/infinity_2b_reg.pth'
        args.checkpoint_type = 'torch'
        args.vae_type = 32
        args.model_type = "infinity_2b"
        args.apply_spatial_patchify = 0
    else:
        # NOTE(review): the 8B files live under an extra 'infinity/' subdirectory
        # of weights_dir, unlike the 2B files — confirm the layout on disk.
        args.vae_path = f'{weights_dir}/infinity/infinity_vae_d56_f8_14_patchify.pth'
        args.model_path = f'{weights_dir}/infinity/infinity_8b_weights'  # 8.4GB
        args.checkpoint_type = 'torch_shard'
        args.vae_type = 14
        args.model_type = "infinity_8b"
        args.apply_spatial_patchify = 1

    args.use_scale_schedule_embedding = 0
    args.use_bit_label = 1
    args.cfg = "3"
    args.tau = 0.5
    args.rope2d_normalized_by_hw = 2
    args.add_lvl_embeding_only_first_block = 1
    args.rope2d_each_sa_layer = 1
    args.text_encoder_ckpt = f'{weights_dir}/models--google--flan-t5-xl/snapshots/7d6315df2c2fb742f0f5b556879d730926ca9001'
    args.text_channels = 2048
    args.sampling_per_bits = 1
    args.enable_positive_prompt = 0
    args.decode_per_scale = False  # Decode per scale, if True, will decode each scale separately
    args.cache_dir = "/dev/shm"
    args.enable_model_cache = 0
    args.bf16 = 1
    args.use_flex_attn = 0
    args.cfg_insertion_layer = 0  # CFG insertion layer, 0 means no CFG insertion

    # Output directory
    args.out_dir = "./tmp"

    # Common arguments
    args.seed = 0
    args.batch_size = 2
    args.watermark_scales = 2  # 0: No watermark, 1: Only last scale, 2: Apply on all scales, 3: up to the 9th scale, 4: from the 10th scale up, 5: [3,4,5]
    args.watermark_delta = 2  # Bias model towards green set by this delta
    args.watermark_context_width = 2
    args.watermark_gen_image = 1  # 0: Only detect watermark given save_file path, 1: generate image based on prompt and detect watermark
    args.watermark_count_bit_loss_after_reencoding = 1  # reencodes the generated image and counts the bits overlap
    args.watermark_count_bit_flip = 1  # counts how many bits are flipped compared to the non-watermarked bits
    args.watermark_method = '2-bit_pattern'  # 2-bit_pattern
    args.set = "01,10"  # Which green-list bit patterns to use for watermarking
    args.dataset_path = '/path/datasets/mscoco2014val'
    args.watermark_add_noise = 0  # Add noise to the watermarked image before detection
    args.watermark_remove_duplicates = 0
    return args

# Create the args namespace
args = create_infinity_args()

In [None]:

from tools.run_infinity import load_tokenizer
from architecture_wrapper import get_architecture, get_vae
# Load the VAE wrapper first: get_architecture below receives it as a dependency.
vae_wrapper = get_vae(args)
# T5 tokenizer/encoder loaded from the checkpoint path configured in create_infinity_args.
# NOTE(review): text_tokenizer / text_encoder are not referenced by the later
# visible cells — confirm whether get_architecture/gen_img need them here.
text_tokenizer, text_encoder = load_tokenizer(t5_path = args.text_encoder_ckpt)
# Build the generator model selected by args (2B or 8B) around the loaded VAE.
model = get_architecture(args,vae_wrapper=vae_wrapper)


In [None]:
import detect_watermark
from extended_watermark_processor import WatermarkDetector
# Detector for the bit-level watermark. delta and green_list are read from
# args (instead of being repeated as literals) so that detection always uses
# the same settings that generation used.
watermark_detector = WatermarkDetector(
    vocab=[0, 1],  # binary token vocabulary (args.use_bit_label = 1)
    gamma=0.5,  # NOTE(review): presumably the green fraction — 2 of the 4 two-bit patterns in args.set; confirm
    delta=args.watermark_delta,
    device="cuda",
    z_threshold=4.0,  # presumably the z-score above which an image is flagged as watermarked — confirm
    ignore_repeated_ngrams=False,
    green_list=args.set,  # same patterns as generation ("01,10")
)


In [None]:
from PIL import Image
from helper import save_single_image
prompt = "A dog"
# Single source of truth for the saved image path: detection must read exactly
# the file that save_single_image writes.
# NOTE(review): args.out_dir ("./tmp") is not used here; the image is written
# relative to the current working directory.
output_image_path = 'tmp.png'

watermark_inference = detect_watermark.WatermarkInference(args, vae_wrapper)
# gen_img returns a 3-tuple; metrics and bits stay as notebook variables for inspection.
metrics, bits, img = model.gen_img([prompt], vae_wrapper.vae, watermark_inference)
# Assumes img is an HxWxC uint8 tensor acceptable to PIL — TODO confirm.
display(Image.fromarray(img.cpu().numpy()))
save_single_image(img, output_image_path)
detect_watermark.detect(args, output_image_path, watermark_detector=watermark_detector, vae_wrapper=vae_wrapper, detect_on_each_scale=False)
