In [1]:
from diffusers_wrapper import StableDiffusion3TextToImage, FluxTextToImage, StableDiffusion2TextToImage, StableDiffusionXLPipelineTextToImage
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Generation settings.
num_images = 4  # number of images generated (and tiled into each grid) per prompt
main_output_path = 'output'  # root folder; results go to <root>/<model_name>/<prompt>/
model_name = 'flux-schnell'  # must be one of the keys of max_sequence_lengths below

# Text sequence length handed to the model wrapper as `max_sequence_length`
# (the sd2 wrapper is constructed without it).
max_sequence_lengths = {
    'sd2': 77,
    'sd3': 77,
    'sdxl': 256,
    'flux-dev': 512,
    'flux-schnell': 256,
}

# Fail early with a readable message instead of a bare KeyError on the lookup below.
if model_name not in max_sequence_lengths:
    raise ValueError(
        f"Unknown model_name {model_name!r}; expected one of {sorted(max_sequence_lengths)}"
    )
max_sequence_length = max_sequence_lengths[model_name]

In [3]:

# Instantiate the text-to-image wrapper matching `model_name`.
# An empty ckpt_dir is passed to every wrapper; presumably this means
# "use the wrapper's default checkpoint location" — TODO confirm in diffusers_wrapper.
ckpt_dir = ''
if 'flux' in model_name:
    # Covers both 'flux-dev' and 'flux-schnell'.
    model_class = FluxTextToImage(model_name=model_name, ckpt_dir=ckpt_dir, num_images=num_images,
                                  max_sequence_length=max_sequence_length)
elif model_name == 'sd2':
    # max_sequence_length is not passed to the SD2 wrapper.
    model_class = StableDiffusion2TextToImage(model_name=model_name, ckpt_dir=ckpt_dir, num_images=num_images)
elif model_name == 'sd3':
    model_class = StableDiffusion3TextToImage(model_name=model_name, ckpt_dir=ckpt_dir, num_images=num_images,
                                              max_sequence_length=max_sequence_length)
elif model_name == 'sdxl':
    model_class = StableDiffusionXLPipelineTextToImage(model_name=model_name, ckpt_dir=ckpt_dir, num_images=num_images,
                                                       max_sequence_length=max_sequence_length)
else:
    raise ValueError(f"Model name not found: {model_name!r}")


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  5.32it/s]it/s]
Loading pipeline components...: 100%|██████████| 7/7 [00:02<00:00,  3.26it/s]


In [18]:
skip_layers = 0
return_grids = True
# Embedding ranges to generate images from; remove any entry you don't need:
#   'full'   - every token position, including the padding
#   'tokens' - only the prompt's own token positions (no padding)
#   'pads'   - only the padding positions
#   'specific_token_idx_to_keep_per_prompt' - only the positions listed in the
#            variable of that name (defined in a later cell)
ranges_to_keep = ['full', 'tokens', 'pads', 'specific_token_idx_to_keep_per_prompt']

prompts = [
    'kids playing in the play ground',
]

# Print every prompt's token positions so indices for
# specific_token_idx_to_keep_per_prompt can be picked.
tokenizers = model_class.get_tokenizers()
# With the Flux wrapper the 'tokenizer_2' entry is the T5 tokenizer (its log prints
# "tokenizer_2 - T5"); for other models, check which tokenizer this key maps to.
# The variable keeps its historical name.
tokenizer_3 = tokenizers['tokenizer_2']
tokenized_prompts = [tokenizer_3.encode(prompt) for prompt in prompts]
for token_ids in tokenized_prompts:
    # Iterate this prompt's ids (not always the first prompt's).
    for subtoken_idx, subtoken in enumerate(token_ids):
        print(subtoken_idx, tokenizer_3.decode(subtoken))

0 kids
1 playing
2 in
3 the
4 play
5 ground
6 </s>


In [21]:

# Token positions to keep for the 'specific_token_idx_to_keep_per_prompt' range.
# Give one inner list per prompt, in the same order as `prompts`; only the
# listed positions are used to generate that prompt's images.
# Example: position 5 is 'ground' in the tokenization printed above.
# (A token-string variant, e.g. specific_tokens_per_prompt = [['kids', 'in']],
#  was sketched earlier but is not used.)
specific_token_idx_to_keep_per_prompt = [[5]]



In [22]:
# Inference: for every prompt, generate and save image grids for each range in
# ranges_to_keep under <main_output_path>/<model_name>/<prompt>/.

# zip() would silently drop prompts without a matching token-index list.
if len(prompts) != len(specific_token_idx_to_keep_per_prompt):
    raise ValueError(
        f"Got {len(prompts)} prompts but {len(specific_token_idx_to_keep_per_prompt)} "
        "entries in specific_token_idx_to_keep_per_prompt; provide one list per prompt"
    )

for prompt, specific_token_idx in zip(prompts, specific_token_idx_to_keep_per_prompt):
    print("generating images for prompt:", prompt)
    output_path = os.path.join(main_output_path, model_name, prompt)
    os.makedirs(output_path, exist_ok=True)

    forward_kwargs = dict(
        num_images=model_class.num_images,
        output_path=output_path,
        save_grid=True,
        save_per_image=False,
        return_grids=return_grids,
        skip_layers=skip_layers,
        ranges_to_keep=ranges_to_keep,
        # Pass only this prompt's indices (previously the non-flux branch passed the
        # whole per-prompt list, unlike the flux branch).
        specific_token_idx_to_keep_per_prompt=specific_token_idx,
    )
    if 'flux' not in model_name:
        # Encoders to pad, per the original note: 1, 2 = CLIP encoders, 3 = T5.
        forward_kwargs['pad_encoders'] = [1, 2]
    # Other optional forward() kwargs tried previously (left off): specific_tokens,
    # zero_paddings=False, replace_with_pads=True, turn_attention_off=False.

    # Only the last prompt's grids remain in `grids` after the loop.
    grids = model_class.forward(prompt, **forward_kwargs)

generating images for prompt: kids playing in the play ground
Validated skip_layers: [0, 0]
ranges_to_keep ['full', 'tokens', 'pads', 'specific_token_idx_to_keep_per_prompt']
Getting range for full
Getting range for tokens
Getting range for pads
Getting range for specific_token_idx_to_keep_per_prompt
st_ground: [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159,

  0%|          | 0/4 [00:00<?, ?it/s]

100%|██████████| 4/4 [00:01<00:00,  2.12it/s]


Image grid for full already exists, skipping
text_ids shape [tokenizer - CLIP]:  torch.Size([1, 77])
Original CLIP pooling code
text_ids shape [tokenizer_2 - T5]:  torch.Size([1, 256])
Skiping num layers:  0
skip_tokens (T5):  [7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180

100%|██████████| 4/4 [00:01<00:00,  2.12it/s]


Image grid for tokens already exists, skipping
text_ids shape [tokenizer - CLIP]:  torch.Size([1, 77])
Original CLIP pooling code
text_ids shape [tokenizer_2 - T5]:  torch.Size([1, 256])
Skiping num layers:  0
skip_tokens (T5):  [0, 1, 2, 3, 4, 5, 6]
replacing 0 replacing 1 replacing 2 replacing 3 replacing 4 replacing 5 replacing 6 Keep token_idx:  7 Keep token_idx:  8 Keep token_idx:  9 Keep token_idx:  10 Keep token_idx:  11 Keep token_idx:  12 Keep token_idx:  13 Keep token_idx:  14 Keep token_idx:  15 Keep token_idx:  16 Keep token_idx:  17 Keep token_idx:  18 Keep token_idx:  19 Keep token_idx:  20 Keep token_idx:  21 Keep token_idx:  22 Keep token_idx:  23 Keep token_idx:  24 Keep token_idx:  25 Keep token_idx:  26 Keep token_idx:  27 Keep token_idx:  28 Keep token_idx:  29 Keep token_idx:  30 Keep token_idx:  31 Keep token_idx:  32 Keep token_idx:  33 Keep token_idx:  34 Keep token_idx:  35 Keep token_idx:  36 Keep token_idx:  37 Keep token_idx:  38 Keep token_idx:  39 Keep tok

100%|██████████| 4/4 [00:01<00:00,  2.12it/s]


Image grid for pads already exists, skipping
text_ids shape [tokenizer - CLIP]:  torch.Size([1, 77])
Original CLIP pooling code
text_ids shape [tokenizer_2 - T5]:  torch.Size([1, 256])
Skiping num layers:  0
skip_tokens (T5):  [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 

100%|██████████| 4/4 [00:01<00:00,  2.11it/s]


Creating image grid for 4 images
Image grid saved to output/flux-schnell/kids playing in the play ground/st_ground/st_ground_[0, 0].png
