In [1]:
import warnings
from functools import partial
from typing import Dict, List, Optional, Union

import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import unreplicate
from flax.training.common_utils import shard
from packaging import version
from PIL import Image
from transformers import CLIPTokenizer, FlaxCLIPTextModel
import time

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
BATCH_SIZE = 2
DEVICES = 4
REPEATS = 2

In [3]:
prompt = "hello ladies and gentlemen today we are going to..."
prompt = [ prompt for i in range(BATCH_SIZE)]
print(prompt)

['hello ladies and gentlemen today we are going to...', 'hello ladies and gentlemen today we are going to...']


In [4]:
tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
text_encoder = FlaxCLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder", revision="flax", dtype=jnp.float16)

In [5]:
tokens = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="np",
)

In [7]:
print(tokens['input_ids'].shape)
print(tokens['attention_mask'].shape)

(2, 77)
(2, 77)


In [6]:
init_time = time.time()
for i in range(REPEATS):
    embeds = text_encoder(output_hidden_states = True, **tokens)
finish_time = time.time()
print(f'Embedding time: {(finish_time - init_time) / BATCH_SIZE / REPEATS} seconds')

Embedding time: 2.6329848766326904 seconds


In [11]:
print(embeds[1].shape)

(2, 768)


In [15]:
print(len(embeds.hidden_states))

13


In [20]:
print(embeds.hidden_states[-2].shape)

(2, 77, 768)


In [31]:
print(embeds.pooler_output.shape)
print(embeds.last_hidden_state.shape)

(1, 768)
(1, 77, 768)


In [19]:
def encode(input_ids: jnp.array, attention_mask: jnp.array):
    return text_encoder(input_ids, attention_mask=attention_mask)[0]

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

rng = create_key(0)
model = jax.pmap(encode, in_axes=(0, 0,))

In [29]:
print(jnp.array([tokens['input_ids'] for i in range(DEVICES)]).shape)

(4, 64, 77)


In [27]:
init_time = time.time()
for i in range(REPEATS):
    output = model(
        jnp.array([tokens['input_ids'] for i in range(DEVICES)]),
        jnp.array([tokens['attention_mask'] for i in range(DEVICES)]),
    )
    output.block_until_ready()
print(f"Time taken: {(time.time() - init_time) / DEVICES / BATCH_SIZE / REPEATS} seconds")

Time taken: 0.0952918529510498 seconds


In [20]:
print(tokens)

{'input_ids': array([[49406,  3306,  3431,   537, 11692,   721,   649,   631,  1245,
          531,   678, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407],
       [49406,  3306,  3431,   537, 11692,   721,   649,   631,  1245,
          531,   678, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 4940