In [1]:
import warnings
from functools import partial
from typing import Dict, List, Optional, Union

import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import unreplicate
from flax.training.common_utils import shard
from packaging import version
from PIL import Image
from transformers import CLIPTokenizer, FlaxCLIPTextModel
import time

  from .autonotebook import tqdm as notebook_tqdm


In [37]:
BATCH_SIZE = 2
DEVICES = 4
REPEATS = 2

In [39]:
prompt = "hello ladies and gentlemen today we are going to..."
prompt = [ prompt for i in range(BATCH_SIZE)]
print(prompt)

['hello ladies and gentlemen today we are going to...', 'hello ladies and gentlemen today we are going to...']


In [4]:
tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
text_encoder = FlaxCLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder", revision="flax", dtype=jnp.float16)



In [40]:
tokens = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="np",
)

In [46]:
print(tokens['input_ids'].shape)
print(tokens['attention_mask'].shape)

(2, 77)
(2, 77)


In [None]:
list = 

In [41]:
init_time = time.time()
for i in range(REPEATS):
    embeds = text_encoder(**tokens)
finish_time = time.time()
print(f'Embedding time: {(finish_time - init_time) / BATCH_SIZE / REPEATS} seconds')

Embedding time: 2.6413785815238953 seconds


In [45]:
print(embeds[1].shape)

(2, 768)


In [13]:
print(type(embeds[0]))

<class 'jaxlib.xla_extension.ArrayImpl'>


In [30]:
print(embeds)

[[[-0.3884   0.02295 -0.05225 ... -0.49    -0.3066   0.0674 ]
  [ 1.456    0.1726  -1.575   ... -0.665   -0.65    -0.3667 ]
  [ 0.8887  -0.2103  -0.275   ...  0.1096  -1.028   -0.496  ]
  ...
  [ 0.3267   0.1011   0.45    ... -0.579   -0.4492  -0.2013 ]
  [ 0.1454   0.07526  0.2927  ... -0.5166  -0.4292  -0.1357 ]
  [ 1.166    0.4116   1.962   ... -0.946    0.987   -0.485  ]]

 [[-0.3884   0.02295 -0.05225 ... -0.49    -0.3066   0.0674 ]
  [ 1.456    0.1726  -1.575   ... -0.665   -0.65    -0.3667 ]
  [ 0.8887  -0.2103  -0.275   ...  0.1096  -1.028   -0.496  ]
  ...
  [ 0.3267   0.1011   0.45    ... -0.579   -0.4492  -0.2013 ]
  [ 0.1454   0.07526  0.2927  ... -0.5166  -0.4292  -0.1357 ]
  [ 1.166    0.4116   1.962   ... -0.946    0.987   -0.485  ]]

 [[-0.3884   0.02295 -0.05225 ... -0.49    -0.3066   0.0674 ]
  [ 1.456    0.1726  -1.575   ... -0.665   -0.65    -0.3667 ]
  [ 0.8887  -0.2103  -0.275   ...  0.1096  -1.028   -0.496  ]
  ...
  [ 0.3267   0.1011   0.45    ... -0.579   -0.44

In [31]:
print(embeds.pooler_output.shape)
print(embeds.last_hidden_state.shape)

(1, 768)
(1, 77, 768)


In [19]:
def encode(input_ids: jnp.array, attention_mask: jnp.array):
    return text_encoder(input_ids, attention_mask=attention_mask)[0]

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

rng = create_key(0)
model = jax.pmap(encode, in_axes=(0, 0,))

In [29]:
print(jnp.array([tokens['input_ids'] for i in range(DEVICES)]).shape)

(4, 64, 77)


In [27]:
init_time = time.time()
for i in range(REPEATS):
    output = model(
        jnp.array([tokens['input_ids'] for i in range(DEVICES)]),
        jnp.array([tokens['attention_mask'] for i in range(DEVICES)]),
    )
    output.block_until_ready()
print(f"Time taken: {(time.time() - init_time) / DEVICES / BATCH_SIZE / REPEATS} seconds")

Time taken: 0.0952918529510498 seconds


In [20]:
print(tokens)

{'input_ids': array([[49406,  3306,  3431,   537, 11692,   721,   649,   631,  1245,
          531,   678, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407],
       [49406,  3306,  3431,   537, 11692,   721,   649,   631,  1245,
          531,   678, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 4940