In [1]:
import torch
import warnings
warnings.filterwarnings("ignore")

from transformers import CLIPTokenizer, T5TokenizerFast

In [2]:
# Tokenizers for the three SD3.5 text encoders (CLIP-L, CLIP-G, T5-XXL); each one
# lives in its own subfolder of the same Hub repository.
pretrained_model_name = "stabilityai/stable-diffusion-3.5-medium"
revision = None


def _load_tokenizer(tokenizer_cls, subfolder):
    """Load ``tokenizer_cls`` from ``subfolder`` of the configured repo and revision."""
    return tokenizer_cls.from_pretrained(
        pretrained_model_name, subfolder=subfolder, revision=revision
    )


tokenizer_clip_l = _load_tokenizer(CLIPTokenizer, "tokenizer")
tokenizer_clip_g = _load_tokenizer(CLIPTokenizer, "tokenizer_2")
tokenizer_t5 = _load_tokenizer(T5TokenizerFast, "tokenizer_3")

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


In [3]:
# Two short example prompts used to compare the tokenizers below.
prompts = [
    "I am Bob",
    "Bob is a myself",
]

In [4]:
def tokenize_prompt(tokenizer, prompt, max_length=77):
    """Tokenize ``prompt`` to a fixed-length tensor of token ids.

    Args:
        tokenizer: A Hugging Face-style tokenizer callable.
        prompt: A string or a list of strings.
        max_length: Length every sequence is padded/truncated to. Defaults to 77,
            the CLIP context length used elsewhere in this notebook.

    Returns:
        The ``input_ids`` tensor produced by the tokenizer.
    """
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    return text_inputs.input_ids

In [5]:
# Batch-tokenize the prompts with each HF tokenizer, padding only to the longest prompt.
tokens_clip_l, tokens_clip_g, tokens_t5 = (
    tok(prompts, return_tensors="pt", padding=True)
    for tok in (tokenizer_clip_l, tokenizer_clip_g, tokenizer_t5)
)

In [6]:
# Show CLIP-L token ids and CLIP-G attention mask, each followed by its shape.
for inspected in (tokens_clip_l['input_ids'], tokens_clip_g['attention_mask']):
    print(inspected)
    print(inspected.shape)

tensor([[49406,   328,   687,  4423, 49407, 49407],
        [49406,  4423,   533,   320,  3245, 49407]])
torch.Size([2, 6])
tensor([[1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1]])
torch.Size([2, 6])


In [7]:
from sd3.stability.other_impls import SD3Tokenizer

# Reference tokenizer from the SD3 codebase, for comparison with the HF tokenizers above.
tokenizer = SD3Tokenizer()

first_prompt = prompts[0]
tokens = tokenizer.tokenize_with_weights(first_prompt)
print(tokens["l"])

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


[[(49406, 1.0), (328, 1.0), (687, 1.0), (4423, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1.0), (49407, 1

In [6]:
from sd3.sd3_train import ClipL, ClipG, T5XXL


@torch.no_grad()
def load_text_encoders(clipl_weights, clipg_weights, t5_weights):
    """Build the CLIP-L, CLIP-G and T5-XXL encoders from their weight files.

    Returns:
        Tuple ``(clip_l, clip_g, t5)`` in that order.
    """
    encoders = (
        ClipL(clipl_weights),
        ClipG(clipg_weights),
        T5XXL(t5_weights),
    )
    return encoders


clipl_weights = "models/official/clip_l.safetensors"
clipg_weights = "models/official/clip_g.safetensors"
t5_weights = "models/official/t5xxl.safetensors"
clip_l, clip_g, t5 = load_text_encoders(
    clipl_weights,
    clipg_weights,
    t5_weights,
)

Skipping key 'shared.weight' in safetensors file as 'shared' does not exist in python model


In [63]:
# SD3Tokenizer returns, per encoder, a list of sequences of (token_id, weight) pairs;
# for the single short prompt above that is one sequence (recorded shape (1, 77, 2)).
token_l_single = tokens["l"]
print(torch.tensor(token_l_single).shape)

# Fake a batch by tiling the sequences (same order as Tensor.repeat(2, 1, 1)).
# Done in pure Python: round-tripping through torch.tensor(...).tolist() promotes every
# pair to float (ids and weights share one dtype), turning token ids into e.g. 49406.0.
n_repeat = 2
token_l = [list(seq) for _ in range(n_repeat) for seq in token_l_single]
print(torch.tensor(token_l).shape)

a = clip_l.model.encode_token_weights(token_l)
# NOTE(review): the recorded output shows batch dim 1 for this batch of 2, so
# encode_token_weights appears to encode only the first sequence — verify before batching.
print(a[0].shape)

torch.Size([1, 77, 2])
torch.Size([2, 77, 2])
torch.Size([1, 77, 768])


In [62]:
# Re-encode the batched CLIP-L tokens and inspect the shape of the first output.
a = clip_l.model.encode_token_weights(token_l)
first_output = a[0]
print(first_output.shape)

torch.Size([1, 77, 768])


In [7]:
def _encode_prompt_with_t5(
    text_encoder,
    tokenizer,
    max_sequence_length,
    prompt=None,
    num_images_per_prompt=1,
    device=None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]

    dtype = text_encoder.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds

def _encode_prompt_with_clip(
    text_encoder,
    tokenizer,
    prompt: str,
    device=None,
    text_input_ids=None,
    num_images_per_prompt: int = 1,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
    
    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

    pooled_prompt_embeds = prompt_embeds[0]
    prompt_embeds = prompt_embeds.hidden_states[-2]
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds, pooled_prompt_embeds

def _resolve_encoder_device(text_encoder, device):
    """Return ``device`` if given, else the device of ``text_encoder``'s first parameter."""
    if device is not None:
        return device
    # Encoders may be wrappers exposing the torch module as `.model` (like the ones
    # loaded in this notebook) or plain torch modules; support both.
    module = getattr(text_encoder, "model", text_encoder)
    return next(module.parameters()).device


def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: str,
    max_sequence_length,
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids_list=None,
):
    """Build SD3-style prompt conditioning from two CLIP encoders and a T5 encoder.

    Args:
        text_encoders: ``[clip_l, clip_g, t5]``.
        tokenizers: Matching ``[clip_l_tok, clip_g_tok, t5_tok]``; entries may be None
            when the corresponding ``text_input_ids_list`` entry is provided.
        prompt: A string or list of strings.
        max_sequence_length: T5 padding/truncation length.
        device: Target device; if None, each encoder's own parameter device is used.
        num_images_per_prompt: Number of times each prompt's embeddings are duplicated.
        text_input_ids_list: Optional pre-tokenized ids, ordered like ``text_encoders``.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)``. The two CLIP hidden states are
        concatenated on the feature axis and zero-padded to T5's width, then joined with
        T5's hidden states on the sequence axis. The two pooled CLIP outputs are
        concatenated on the feature axis.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt

    clip_prompt_embeds_list = []
    clip_pooled_prompt_embeds_list = []
    for i, (tokenizer, text_encoder) in enumerate(zip(tokenizers[:2], text_encoders[:2])):
        prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            prompt=prompt,
            device=_resolve_encoder_device(text_encoder, device),
            num_images_per_prompt=num_images_per_prompt,
            text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
        )
        clip_prompt_embeds_list.append(prompt_embeds)
        clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)

    clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)

    t5_prompt_embed = _encode_prompt_with_t5(
        text_encoders[-1],
        tokenizers[-1],
        max_sequence_length,
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
        device=_resolve_encoder_device(text_encoders[-1], device),
    )

    # Zero-pad the CLIP feature axis up to T5's width so both can share the sequence axis.
    clip_prompt_embeds = torch.nn.functional.pad(
        clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
    )
    prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)

    return prompt_embeds, pooled_prompt_embeds

In [12]:
# max_sequence_length = 77

# prompt_embeds, pooled_prompt_embeds = encode_prompt(
#     text_encoders=[clip_l, clip_g, t5],
#     tokenizers=[None, None, None],
#     prompt=prompts,
#     max_sequence_length=max_sequence_length,
#     text_input_ids_list=[tokens_clip_l, tokens_clip_g, tokens_t5],
# )

In [None]:
class SD3Trainer:
    """Bundle the SD3 tokenizers and text encoders and build conditioning from prompts.

    Args:
        tokenizer_clip_l, tokenizer_clip_g, tokenizer_t5: Tokenizers for the three encoders.
        clip_l, clip_g, t5xxl: Encoders exposing ``.model.encode_token_weights(tokens)``,
            which returns ``(hidden_states, pooled)``.
    """

    def __init__(
        self,
        tokenizer_clip_l,
        tokenizer_clip_g,
        tokenizer_t5,
        clip_l,
        clip_g,
        t5xxl,
    ):
        self.tokenizer_clip_l = tokenizer_clip_l
        self.clip_l = clip_l

        self.tokenizer_clip_g = tokenizer_clip_g
        self.clip_g = clip_g

        self.tokenizer_t5 = tokenizer_t5
        self.t5xxl = t5xxl

    def get_cond(self, prompts):
        """Return ``(prompt_embeds, pooled_prompt_embeds)`` for ``prompts``.

        The CLIP-L and CLIP-G hidden states are concatenated on the feature axis and
        zero-padded to the T5 feature width, then concatenated with the T5 hidden states
        along the sequence axis. The pooled CLIP-L and CLIP-G outputs are concatenated
        on the feature axis. The T5 pooled output is unused.
        """
        tokens = self.tokenize_with_weights(prompts)

        l_out, l_pooled = self.clip_l.model.encode_token_weights(tokens["l"])
        g_out, g_pooled = self.clip_g.model.encode_token_weights(tokens["g"])
        t5_out, _t5_pooled = self.t5xxl.model.encode_token_weights(tokens["t5xxl"])

        lg_out = torch.cat([l_out, g_out], dim=-1)
        # Pad to T5's actual width (4096 for T5-XXL) instead of a hard-coded constant.
        lg_out = torch.nn.functional.pad(lg_out, (0, t5_out.shape[-1] - lg_out.shape[-1]))

        prompt_embeds = torch.cat([lg_out, t5_out], dim=-2)
        pooled_prompt_embeds = torch.cat((l_pooled, g_pooled), dim=-1)
        return prompt_embeds, pooled_prompt_embeds

    def tokenize_with_weights(self, prompts):
        """Tokenize ``prompts`` for each encoder. Not implemented yet.

        Must return a dict with keys ``"l"``, ``"g"`` and ``"t5xxl"``, as consumed by
        ``get_cond``. Presumably each value should use the per-encoder format that
        ``SD3Tokenizer.tokenize_with_weights`` produces: lists of (token_id, weight)
        sequences. TODO confirm.

        Raises:
            NotImplementedError: Always, until implemented. Previously this returned
                ``Ellipsis`` placeholders that failed later inside the encoders.
        """
        raise NotImplementedError(
            "SD3Trainer.tokenize_with_weights is not implemented: it must return a dict "
            "with keys 'l', 'g' and 't5xxl' holding (token_id, weight) sequences"
        )

# Wire the HF tokenizers and the SD3 text encoders loaded above into the trainer.
sd3_trainer = SD3Trainer(
    tokenizer_clip_l=tokenizer_clip_l,
    tokenizer_clip_g=tokenizer_clip_g,
    tokenizer_t5=tokenizer_t5,
    clip_l=clip_l,
    clip_g=clip_g,
    t5xxl=t5,
)