# DallE-Mini Text to Image Project

This notebook walks through the setup of dalle-mini, JAX, and related components without relying on the Hugging Face or Weights & Biases caches. The model files are copied into a local folder, so you can run the application on any machine with a supported (GPU) runtime and generate images from text prompts.


## Environment setup and verification



### GPU setup

Import JAX

In [None]:
import jax
import jax.numpy as jnp

In Google Colab, select Runtime > Change runtime type > GPU, then verify that the GPU is loaded properly:

In [None]:
jax.local_device_count()

1

In [None]:
jax.devices()

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
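Outside Colab it can help to fail early when no GPU is visible, rather than discovering it at generation time. The helper below is a hypothetical sketch: `jax.devices()` lists the devices of the default backend, and CUDA devices report `platform == "gpu"`.

```python
import jax

def require_gpu():
    """Hypothetical helper: return the number of GPUs JAX sees, or raise if there are none."""
    gpus = [d for d in jax.devices() if d.platform == "gpu"]
    if not gpus:
        raise RuntimeError(
            f"No GPU visible to JAX (default backend: {jax.default_backend()!r}); "
            "in Colab select Runtime > Change runtime type > GPU"
        )
    return len(gpus)
```

Calling `require_gpu()` at the top of a script gives a clear error message instead of a slow CPU run or a failure deep inside generation.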

## Install and Import all AI model components

### Install DALLE-mini

In [None]:
!pip install -q dalle-mini


### Install the JAX version of VQGAN

In [None]:
!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git

  Building wheel for vqgan-jax (setup.py) ... done


Import all model components

In [None]:
from dalle_mini import DalleBart, DalleBartProcessor
# JAX (Flax) implementation of VQGAN
from vqgan_jax.modeling_flax_vqgan import VQModel
# CLIP components (imported here but not used in the rest of this notebook)
from transformers import CLIPProcessor, FlaxCLIPModel

### Import the Hugging Face Hub library (already installed on Google Colab)

In [None]:
# only hf_hub_download is used below (cached_download is deprecated in recent huggingface_hub releases)
from huggingface_hub import hf_hub_download

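#### Download all DallE-mini files from Hugging Face
The files are listed at the following link:

https://huggingface.co/dalle-mini/dalle-mini/tree/main

Note: all files are saved locally. The target folders /content/dalle-mini and /content/dalle-mini/vqgan must exist before files are copied into them (for example, create them in the Colab file browser under content > dalle-mini > vqgan); otherwise `shutil.copy` raises `FileNotFoundError`.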

In [None]:
import shutil

In [None]:
dalle_mini_file_list = ['config.json', 'enwiki-words-frequency.txt', 
                         'flax_model.msgpack', 'merges.txt', 
                         'special_tokens_map.json', 'tokenizer.json', 
                         'tokenizer_config.json', 'vocab.json']

In [None]:
import os

target_dir = '/content/dalle-mini'
# shutil.copy does not create folders, so make sure the target exists
os.makedirs(target_dir, exist_ok=True)

for file in dalle_mini_file_list:
  # download to the local Hugging Face cache folder
  downloaded_file = hf_hub_download('dalle-mini/dalle-mini', filename=file)
  target_path = os.path.join(target_dir, file)
  # copy the file from the cache folder to the local dalle-mini folder;
  # this removes the dependency on the library cache so the files can be used anywhere
  # (e.g. referenced from our API)
  shutil.copy(downloaded_file, target_path)



Verify that the files downloaded successfully and that they are local copies rather than symlinks (symlinks did not work when loading the models):

In [None]:
!ls -lah /content/dalle-mini
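The same check can be done programmatically. The sketch below uses a hypothetical helper, `check_local_files`, that reports which required files are missing and which are still symlinks. The small demo in a throwaway folder also shows why the download loops work: `shutil.copy` follows symlinks by default, so copying a symlinked cache entry produces a regular file.

```python
import os
import shutil
import tempfile

def check_local_files(folder, required):
    """Return (missing, symlinked) file names for a local model folder."""
    missing, symlinked = [], []
    for name in required:
        path = os.path.join(folder, name)
        if not os.path.exists(path):   # also true for a broken symlink
            missing.append(name)
        elif os.path.islink(path):
            symlinked.append(name)
    return missing, symlinked

# Demo: a symlinked "cache" file copied with shutil.copy becomes a regular file
with tempfile.TemporaryDirectory() as tmp:
    cache_dir = os.path.join(tmp, "cache")
    local_dir = os.path.join(tmp, "local")
    os.makedirs(cache_dir)
    os.makedirs(local_dir)
    blob = os.path.join(cache_dir, "blob")
    with open(blob, "w") as f:
        f.write("{}")
    link = os.path.join(cache_dir, "config.json")
    os.symlink(blob, link)
    shutil.copy(link, os.path.join(local_dir, "config.json"))
    demo_result = check_local_files(local_dir, ["config.json", "flax_model.msgpack"])

print(demo_result)  # (['flax_model.msgpack'], [])
```

In the notebook, `check_local_files('/content/dalle-mini', dalle_mini_file_list)` should return two empty lists once the download cell has run.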

### Retrieve all VQGAN files from Hugging Face
The files are listed at the following link:

https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/tree/main

In [None]:
vqgan_file_list = ['config.json', 'flax_model.msgpack']

In [None]:
import os

vqgan_dir = '/content/dalle-mini/vqgan'
# shutil.copy does not create folders, so make sure the target exists
os.makedirs(vqgan_dir, exist_ok=True)

for file in vqgan_file_list:
  # download to the local Hugging Face cache folder
  downloaded_file = hf_hub_download('dalle-mini/vqgan_imagenet_f16_16384', filename=file)
  target_path = os.path.join(vqgan_dir, file)
  # copy the file from the cache folder to the local vqgan folder;
  # this removes the dependency on the library cache so the files can be used anywhere
  # (e.g. referenced from our API)
  shutil.copy(downloaded_file, target_path)

Verify that the files downloaded successfully and that they are local copies:

In [None]:
!ls -lah /content/dalle-mini/vqgan

## Load Models

### Load DallE-mini model

Uses the flax_model.msgpack and config.json files.

In [None]:
DALLE_MODEL_PATH = '/content/dalle-mini'
DALLE_COMMIT_ID = None
# _do_init=False skips random weight initialization, so from_pretrained
# returns the model and its params separately
dalle_model, dalle_params = DalleBart.from_pretrained(
    DALLE_MODEL_PATH, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False,
)

Validate DallE-mini Model

In [None]:
dalle_model

In [None]:
dalle_model.config

In [None]:
# Uncomment to view the model params (left commented out to keep the notebook output short)
# dalle_params

### Load VQGAN Model

In [None]:
VQGAN_MODEL_PATH = '/content/dalle-mini/vqgan'
VQGAN_COMMIT_ID = None
vqgan_model, vqgan_params = VQModel.from_pretrained(
    VQGAN_MODEL_PATH, revision = VQGAN_COMMIT_ID, dtype=jnp.float16, _do_init=False,
)

Validate VQGAN Model

In [None]:
vqgan_model.config

In [None]:
# vqgan_params  # uncomment to view the params (left commented out to keep the notebook output short)

### Load DallE Bart Processor

Uses the downloaded tokenizer files (e.g. tokenizer.json, vocab.json, merges.txt).

In [None]:
# use the same paths as dalle-mini
DALLE_MODEL_PATH = '/content/dalle-mini'
DALLE_COMMIT_ID = None
dalle_bart_processor = DalleBartProcessor.from_pretrained(
    DALLE_MODEL_PATH, revision = DALLE_COMMIT_ID)



Validate DallE Bart Processor

In [None]:
dalle_bart_processor

## Multi-GPU Setup

How much this setup buys you depends on the hardware available. Multi-GPU execution is enabled by replication: `replicate` copies the parameters onto every available device. The step is still needed on a single GPU, because the generation and decoding functions below use `jax.pmap`, which expects a leading device axis on its inputs.

Note: Google Colab provides only one GPU.

In [None]:
from flax.jax_utils import replicate
params = replicate(dalle_params)
vqgan_params = replicate(vqgan_params)
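To see what replication does to shapes: `flax.jax_utils.replicate` gives every leaf of the parameter tree a leading axis of size `jax.local_device_count()` and places one copy on each device, and that leading axis is what `jax.pmap` maps over. The sketch below only reproduces the shapes with `jnp.stack` (it does not place copies on separate devices) and uses a hypothetical toy parameter tree rather than the DallE-mini params.

```python
import jax
import jax.numpy as jnp

def replicate_shapes_only(tree, n_devices=None):
    """Add a leading device axis to every leaf, like flax.jax_utils.replicate.

    Illustration only: all copies stay on the default device.
    """
    n = jax.local_device_count() if n_devices is None else n_devices
    return jax.tree_util.tree_map(lambda x: jnp.stack([jnp.asarray(x)] * n), tree)

# hypothetical toy parameter tree (not the DallE-mini params)
toy_params = {"dense": {"kernel": jnp.ones((3, 2)), "bias": jnp.zeros((2,))}}
toy_replicated = replicate_shapes_only(toy_params)
print(jax.tree_util.tree_map(lambda x: x.shape, toy_replicated))
```

On Colab's single GPU the leading axis has size 1, so `kernel` goes from (3, 2) to (1, 3, 2).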

## Model Inference
### Encode text to Images

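`functools.partial` lets `jax.pmap` be used as a decorator with extra arguments (`axis_name`, `static_broadcasted_argnums`); `jax.pmap` itself runs the function in parallel across devices.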

In [None]:
from functools import partial

Use the main DallE-mini model to generate (encode) images: it takes the tokenized prompt and produces a sequence of image tokens, i.e. VQGAN codebook indices that are decoded into pixels in the next step.

In [None]:
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums =(3,4,5,6))
def p_generate(tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale):
  return dalle_model.generate(
      **tokenized_prompt, 
      prng_key=key,
      params = params,
      top_k = top_k,
      top_p = top_p,
      temperature = temperature,
      condition_scale= condition_scale,
  )
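`jax.pmap` compiles the function and runs one copy per device. Array arguments are split along their leading device axis, while arguments listed in `static_broadcasted_argnums` are treated as hashable Python constants sent unchanged to every device and baked into the compiled code; changing one triggers recompilation. That is why `top_k`, `top_p`, `temperature` and `condition_scale` (positions 3 to 6) are static above: they may be `None` or plain numbers. The toy function below is a hypothetical illustration of the same pattern, not part of dalle-mini.

```python
from functools import partial

import jax
import jax.numpy as jnp

# x and w are split along their leading device axis; `scale` (argument 2)
# is a static Python value broadcast to every device.
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(2,))
def scaled_dot(x, w, scale):
    return jnp.sum(x * w) * scale

n = jax.local_device_count()
x = jnp.ones((n, 4))                   # one row per device
w = jnp.stack([jnp.arange(4.0)] * n)   # replicated weights, shape (n, 4)
out = scaled_dot(x, w, 2.0)            # one result per device, shape (n,)
print(out)                             # each entry is (0 + 1 + 2 + 3) * 2 = 12.0
```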

### Decode Images

In [None]:
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
  return vqgan_model.decode_code(indices, params=params)
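As a rough check on the shapes involved, assuming the standard DallE-mini setup: the model emits 256 image tokens per image (after the leading beginning-of-sequence token is dropped in the generation loop), read as a 16×16 grid of VQGAN codebook indices. In `vqgan_imagenet_f16_16384`, `f16` is the downsampling factor and `16384` the codebook size, so each code expands to a 16×16 pixel patch; this is where the `reshape((-1, 256, 256, 3))` in the generation loop comes from. The helper name is hypothetical.

```python
import math

def decoded_image_shape(num_image_tokens=256, downsample_factor=16, channels=3):
    """Pixel shape produced by decoding a square grid of VQGAN codes."""
    side = math.isqrt(num_image_tokens)
    if side * side != num_image_tokens:
        raise ValueError("expected a square number of image tokens")
    return (side * downsample_factor, side * downsample_factor, channels)

print(decoded_image_shape())  # (256, 256, 3)
```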


### Exercise Encoder Decoder Pipeline

Create example text input prompt

In [None]:
prompt = ['vincent van gogh paintings mixed with pumpkins']
# Tokenize the prompt with the DalleBart processor
tokenized_prompts = dalle_bart_processor(prompt)

Replicate the tokenized prompts onto every GPU device (each device gets its own copy):

In [None]:
tokenized_prompt = replicate(tokenized_prompts)
print(tokenized_prompt)

#### Defining Model Parameters

##### Random Key Parameter

In [None]:
import random

# draw a random seed and create a JAX PRNG key from it
seed = random.randint(0, 2**32 - 1)
key = jax.random.PRNGKey(seed)
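JAX random keys are explicit values: reusing a key reproduces the same random numbers, so the generation loop splits off a fresh `subkey` on every step, and `shard_prng_key` turns that subkey into one key per device. The sketch below mirrors that pattern with plain `jax.random.split`, assuming (as in Flax) that sharding a key amounts to splitting it into `jax.local_device_count()` keys; the helper name is hypothetical. Fixing `seed` instead of drawing it with `random.randint` makes a run reproducible.

```python
import jax

def per_device_keys(seed, num_steps, n_devices=None):
    """Return one array of per-device keys for each generation step (hypothetical helper)."""
    n = jax.local_device_count() if n_devices is None else n_devices
    key = jax.random.PRNGKey(seed)
    batches = []
    for _ in range(num_steps):
        key, subkey = jax.random.split(key)          # fresh subkey per step, as in the loop
        batches.append(jax.random.split(subkey, n))  # one key per device
    return batches

keys = per_device_keys(seed=42, num_steps=2)
print(keys[0].shape)  # leading axis = number of local devices
```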

In [None]:
# number of predictions (images) per prompt
num_predictions = 4
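The generation loop runs `max(num_predictions // jax.device_count(), 1)` steps, and every step produces one image per prompt on each device. A small hypothetical helper makes the resulting counts explicit: on Colab's single GPU, `num_predictions = 4` gives 4 steps and 4 images, while on 8 devices it gives 1 step and 8 images, and on 3 devices 1 step and only 3 images.

```python
def generation_plan(num_predictions, device_count, num_prompts=1):
    """Return (steps, total_images) for the pmapped generation loop (hypothetical helper)."""
    steps = max(num_predictions // device_count, 1)  # same formula as the loop
    images_per_step = device_count * num_prompts     # one image per prompt on each device
    return steps, steps * images_per_step

print(generation_plan(4, 1))  # (4, 4)
print(generation_plan(4, 8))  # (1, 8)
print(generation_plan(4, 3))  # (1, 3)
```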

##### Customize Generation Parameters
* Resource: https://huggingface.co/blog/how-to-generate

In [None]:
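# None leaves each setting at the model's default
gen_top_k = None
gen_top_p = None
temperature = None
# conditioning scale (super conditioning); higher values follow the prompt more closely
cond_scale = 10.0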
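The linked Hugging Face post describes these sampling settings. As a NumPy sketch of the standard techniques (not dalle-mini's own code): `temperature` divides the logits, `top_k` keeps only the k most likely tokens, and `top_p` keeps the smallest set of tokens whose probabilities sum to at least `top_p`. `condition_scale` is assumed here to work like classifier-free guidance, mixing logits computed with and without the prompt as `uncond + scale * (cond - uncond)`; a scale of 1.0 returns the conditioned logits unchanged. Unlike the notebook's `None` values (model defaults), the hypothetical helpers below default to `temperature=1.0` and no filtering.

```python
import numpy as np

def apply_condition_scale(logits_cond, logits_uncond, condition_scale):
    """Guidance-style mixing: push logits away from the unconditioned prediction."""
    logits_cond = np.asarray(logits_cond, dtype=np.float64)
    logits_uncond = np.asarray(logits_uncond, dtype=np.float64)
    return logits_uncond + condition_scale * (logits_cond - logits_uncond)

def sample_next_token(logits, rng, top_k=None, top_p=None, temperature=1.0):
    """Sample one token index with temperature, top-k and top-p (nucleus) filtering."""
    logits = np.asarray(logits, dtype=np.float64) / temperature
    if top_k:  # None or 0 disables top-k
        kth_largest = np.sort(logits)[-top_k]
        logits = np.where(logits < kth_largest, -np.inf, logits)  # keep the k best
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    if top_p is not None:
        order = np.argsort(probs)[::-1]                         # most likely first
        cumulative = np.cumsum(probs[order])
        cutoff = int(np.searchsorted(cumulative, top_p)) + 1    # smallest set reaching top_p
        filtered = np.zeros_like(probs)
        filtered[order[:cutoff]] = probs[order[:cutoff]]
        probs = filtered / filtered.sum()
    return int(rng.choice(len(probs), p=probs))

rng = np.random.default_rng(0)
logits = apply_condition_scale([1.0, 3.0, 2.0], [1.0, 1.0, 1.0], condition_scale=10.0)
print(logits)                                  # element-wise: 1, 21, 11
print(sample_next_token(logits, rng, top_k=1))  # 1 (only the best token survives)
```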


#### Generate Images

In [None]:
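from flax.training.common_utils import shard_prng_key
import numpy as np
from PIL import Image
from tqdm.notebook import trange
# display is built in on Colab/Jupyter; the import makes it explicit elsewhere
from IPython.display import display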

In [None]:
print(f"Prompts: {prompt}")

In [None]:
images = []
# each step generates one image per prompt on every device (jax.device_count() is 1 on Colab)
for i in trange(max(num_predictions // jax.device_count(), 1)):
  # get a new key for this step
  key, subkey = jax.random.split(key)

  # Encoder: generate image tokens from the tokenized prompt
  encoded_images = p_generate(
      tokenized_prompt,
      shard_prng_key(subkey),
      params,
      gen_top_k,
      gen_top_p,
      temperature,
      cond_scale,
  )

  # remove the beginning-of-sequence (BOS) token
  encoded_images = encoded_images.sequences[..., 1:]

  # Decoder: decode the image tokens into pixels with the VQGAN
  decoded_images = p_decode(encoded_images, vqgan_params)

  # clamp pixel values to [0, 1] (array clipping, not CLIP ranking)
  # and flatten the device axis into a batch of 256x256 RGB images
  decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))

  # convert each image to an 8-bit RGB PIL image, store it, and display it
  for decoded_img in decoded_images:
    img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
    images.append(img)
    display(img)
    print()
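The final conversion maps float pixels in [0, 1] to 8-bit values. Casting with `dtype=np.uint8` truncates toward zero, so `0.999 * 255` becomes 254; the hypothetical helper below rounds to the nearest value instead and clips again so it can be reused outside the loop. It accepts NumPy or JAX arrays of shape (H, W, 3).

```python
import numpy as np
from PIL import Image

def to_pil_image(decoded_img):
    """Convert an (H, W, 3) float image in [0, 1] to an 8-bit RGB PIL image."""
    pixels = np.clip(np.asarray(decoded_img, dtype=np.float32), 0.0, 1.0)
    # round to the nearest 8-bit value instead of truncating
    return Image.fromarray(np.round(pixels * 255).astype(np.uint8))

print(to_pil_image(np.full((2, 2, 3), 0.999)).getpixel((0, 0)))  # (255, 255, 255)
```

For example, `to_pil_image(decoded_img).save('image_0.png')` would write one generated image to disk.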

  

If you run into errors in the cell above, make sure the runtime is set to use a GPU (not TPU or CPU). This may require rerunning all code cells in the notebook.

Check runtime environment

In [None]:
!nvidia-smi

Resources: https://www.youtube.com/watch?v=uVYZR6Wab7o