## DALL-E-mini inference
This notebook is a copy of: https://github.com/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb

The goal is to try out text2img inference and test its feasibility for computing a language drift metric.

In [1]:
%load_ext memory_profiler

In [2]:
# Install required libraries
!pip install -q dalle-mini
!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git

In [3]:
# Model references

# dalle-mega
# DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"  # can be wandb artifact or 🤗 Hub or local folder or google bucket
DALLE_COMMIT_ID = None

# if the notebook crashes too often you can use dalle-mini instead by uncommenting below line
DALLE_MODEL = "flax-community/dalle-mini"#"flax-community/dalle-mini" #"dalle-mini/dalle-mini/mini-1:v0"

# VQGAN model
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"



In [4]:
import jax
import jax.numpy as jnp

# check how many devices are available
jax.local_device_count()



1

In [12]:
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# tokenizer = AutoTokenizer.from_pretrained(DALLE_MODEL, from_flax=True)

# model = AutoModelForSeq2SeqLM.from_pretrained(DALLE_MODEL, from_flax=True)

ValueError: Flax checkpoint seems to be incorrect. Weight ('final_logits_bias',) was expected to be of shape torch.Size([1, 50264]), but is (1, 16385).

In [None]:
# https://huggingface.co/spaces/flax-community/dalle-mini/blob/99a1ff5bc66b8a85a91e9505e2f61d8080dd7360/demo/CustomBARTv4b_model-generate.ipynb

In [11]:
"# TODO: set those args in a config file\n",
OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos\n",
OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos\n",
BOS_TOKEN_ID = 16384
BASE_MODEL = 'facebook/bart-large'

In [9]:
"import jax\n",
import flax.linen as nn

from transformers.models.bart.modeling_flax_bart import *
from transformers import BartTokenizer, FlaxBartForConditionalGeneration

class CustomFlaxBartModule(FlaxBartModule):
    def setup(self):
        # we keep shared to easily load pre-trained weights\n",
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
            dtype=self.dtype,
        ),
        # a separate embedding is used for the decoder\n",
        self.decoder_embed = nn.Embed(
            OUTPUT_VOCAB_SIZE,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
            dtype=self.dtype,
        )
        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared),
        # the decoder has a different config\n",
        decoder_config = BartConfig(self.config.to_dict())
        decoder_config.max_position_embeddings = OUTPUT_LENGTH
        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)

class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
    def setup(self):
        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            OUTPUT_VOCAB_SIZE,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
        )
        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
    module_class = CustomFlaxBartForConditionalGenerationModule

In [12]:
"# load pre-trained model for encoder weights\n",
base_model = FlaxBartForConditionalGeneration.from_pretrained(BASE_MODEL)

Downloading:   0%|          | 0.00/1.59k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/775M [00:00<?, ?B/s]

Some weights of FlaxBartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large and are newly initialized: {('final_logits_bias',)}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [17]:
"# set up our new model config\n",
config = BartConfig.from_pretrained(BASE_MODEL)
config.tie_word_embeddings = False
config.decoder_start_token_id = BOS_TOKEN_ID
config.bos_token_id = BOS_TOKEN_ID  # should not be used\n",
config.pos_token_id = BOS_TOKEN_ID  # should not be used\n",
config.eos_token_id = None  # prevents generation from stopping until we reach max_length"

In [15]:
config

BartConfig {
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 16384,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 1024,
  "model_type": "bart",
  "no_repeat_ngram_si

In [18]:
      "# create our model and initialize it randomly\n",
model = CustomFlaxBartForConditionalGeneration(config)

TypeError: 'tuple' object is not callable

In [16]:
 "# use pretrained weights\n",
model.params['model']['encoder'] = base_model.params['model']['encoder']
model.params['model']['shared'] = base_model.params['model']['shared']

NameError: name 'model' is not defined

In [None]:
"# we verify that the shape has not been modified\n",
model.params['final_logits_bias'].shape

In [None]:
## inference

In [None]:
tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)

In [None]:
text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors='jax')
encoder_outputs = model.encode(**inputs)

In [None]:
decoder_start_token_id = model.config.decoder_start_token_id
decoder_start_token_id

In [None]:
decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
outputs = model.decode(decoder_input_ids, encoder_outputs)

In [None]:
input_ids_test = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')

In [None]:
greedy_output = model.generate(input_ids_test, max_length=50)

In [None]:
## copied from repo below 

In [6]:
from transformers.models.bart.modeling_flax_bart import (
    FlaxBartAttention,
    FlaxBartForConditionalGeneration,
    FlaxBartForConditionalGenerationModule,
    FlaxBartModule,
)

In [7]:
class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
    """
    Edits:
    - no bias
    - lm_head set to image_vocab_size + 1 (for BOS)
    - uses custom FlaxBartModule
    """

    def setup(self):
        self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.config.image_vocab_size
            + 1,  # image vocab size + 1 for BOS to have same size as decoder inputs (for sharding)
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        hidden_states = outputs[0]

        if self.config.tie_word_embeddings:
            shared_embedding = self.model.variables["params"]["shared"]["embedding"]
            lm_logits = self.lm_head.apply(
                {"params": {"kernel": shared_embedding.T}}, hidden_states
            )
        else:
            lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return output

        return FlaxSeq2SeqLMOutput(
            logits=lm_logits,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

In [8]:
# model seems to be 1.64GB big. not sure which one tho
# Load models & tokenizer
from dalle_mini import DalleBart, DalleBartProcessor#, FlaxBartForConditionalGeneration
from vqgan_jax.modeling_flax_vqgan import VQModel
from transformers import CLIPProcessor, FlaxCLIPModel

# Load dalle-mini
model, params = FlaxBartForConditionalGenerationModule.from_pretrained(
    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16#, ignore_mismatched_sizes=True #, _do_init=False
)

# Load VQGAN
vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False
)



AttributeError: type object 'FlaxBartForConditionalGenerationModule' has no attribute 'from_pretrained'

In [None]:
# Model parameters are replicated on each device for faster inference.

from flax.jax_utils import replicate

params = replicate(params)
vqgan_params = replicate(vqgan_params)

# Model functions are compiled and parallelized to take advantage of multiple devices.
from functools import partial

# model inference
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
    return model.generate(
        **tokenized_prompt,
        prng_key=key,
        params=params,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=condition_scale,
    )


# decode image
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    return vqgan.decode_code(indices, params=params)

In [None]:
# Keys are passed to the model on each device to generate unique inference per device.
import random

# create a random key
seed = random.randint(0, 2**32 - 1)
key = jax.random.PRNGKey(seed)

In [None]:
# Text
from dalle_mini import DalleBartProcessor

processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)


In [None]:
prompts = [
    "sunset over a lake in the mountains",
    "the Eiffel tower landing on the moon",
]


In [None]:


tokenized_prompts = processor(prompts)



In [None]:
tokenized_prompt = replicate(tokenized_prompts)

In [None]:

# We generate images using dalle-mini model and decode them with the VQGAN.

# number of predictions per prompt
n_predictions = 8

# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)
gen_top_k = None
gen_top_p = None
temperature = None
cond_scale = 10.0

In [None]:


from flax.training.common_utils import shard_prng_key
import numpy as np
from PIL import Image
from tqdm.notebook import trange

print(f"Prompts: {prompts}\n")
# generate images
images = []
for i in trange(max(n_predictions // jax.device_count(), 1)):
    # get a new key
    key, subkey = jax.random.split(key)
    # generate images
    encoded_images = p_generate(
        tokenized_prompt,
        shard_prng_key(subkey),
        params,
        gen_top_k,
        gen_top_p,
        temperature,
        cond_scale,
    )
    # remove BOS
    encoded_images = encoded_images.sequences[..., 1:]
    # decode images
    decoded_images = p_decode(encoded_images, vqgan_params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    for decoded_img in decoded_images:
        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
        images.append(img)
        display(img)
        print()

