# Training a Stable Diffusion Model with Textual Inversion

**Objectives**: Fine-tune a Stable Diffusion model on a few Bliss symbol images so that it associates them with the token `<bliss-symbol>`. The model is trained with Flax/JAX.

**Parameters**: All parameters are set in step 2.2. See [Example of textual inversion training with Flax/JAX](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion#training-with-flaxjax) for a description of each parameter.

**Note**: If you run this job on Azure, fill in the sensitive information for connecting to the Azure subscription in
section 1 before running. If you use another platform, replace that section with the
credential verification for that platform.

**References**
* [Using Textual Inversion Embeddings to gain substantial control over your generated images](https://blog.paperspace.com/dreambooth-stable-diffusion-tutorial-part-2-textual-inversion/)
* [How to Fine-tune Stable Diffusion using Textual Inversion](https://towardsdatascience.com/how-to-fine-tune-stable-diffusion-using-textual-inversion-b995d7ecc095)
* [Example of textual inversion training](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion)


### 1. Connect to Azure Machine Learning workspace

Before we dive into the code, you'll need to connect to your workspace. The workspace is the top-level resource for Azure Machine Learning, providing a centralized place to work with all the artifacts you create when you use Azure Machine Learning.

We are using `DefaultAzureCredential` to get access to the workspace. `DefaultAzureCredential` should be capable of handling most scenarios. To learn more about the other available credentials, see the [set up authentication doc](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-setup-authentication?tabs=sdk) and the [azure-identity reference doc](https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity?view=azure-python).

**Make sure to enter your workspace credentials before you run the script below.**

In [None]:
# Handle to the workspace
from azure.ai.ml import MLClient

# Authentication package
from azure.identity import DefaultAzureCredential

# Credential object handed to the workspace client below.
credential = DefaultAzureCredential()

# Coordinates of the target workspace. Replace every "..." with the values shown
# on the workspace tab on ml.azure.com before running this cell.
workspace_coordinates = {
    "subscription_id": "...",  # this will look like xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
    "resource_group_name": "...",
    "workspace_name": "...",
}

# Handle to the workspace.
ml_client = MLClient(credential=credential, **workspace_coordinates)

### 2. Settings

#### 2.1 import required libs

In [None]:
import argparse
import logging
import math
import os
import random
from pathlib import Path
from typing import Optional

import mlflow
import mlflow.sklearn

import jax
import jax.numpy as jnp
import numpy as np
import optax
import PIL
import torch
import torch.utils.checkpoint
import transformers
from flax import jax_utils
from flax.training import train_state
from flax.training.common_utils import shard
from huggingface_hub import HfFolder, Repository, create_repo, whoami

# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed

from diffusers import (
    FlaxAutoencoderKL,
    FlaxDDPMScheduler,
    FlaxPNDMScheduler,
    FlaxStableDiffusionPipeline,
    FlaxUNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
from diffusers.utils import check_min_version

# From Pillow 9.1.0 on, the resampling filters are read from `PIL.Image.Resampling`;
# for older releases they are read directly from `PIL.Image`.
_pil_release = version.parse(version.parse(PIL.__version__).base_version)
if _pil_release < version.parse("9.1.0"):
    _filters = PIL.Image
    _linear_filter = PIL.Image.LINEAR
else:
    _filters = PIL.Image.Resampling
    _linear_filter = PIL.Image.Resampling.BILINEAR

# Maps the interpolation names used in this notebook to Pillow resampling filters.
PIL_INTERPOLATION = {
    "linear": _linear_filter,
    "bilinear": _filters.BILINEAR,
    "bicubic": _filters.BICUBIC,
    "lanczos": _filters.LANCZOS,
    "nearest": _filters.NEAREST,
}


#### 2.2 set up constants

In [None]:
# --- Model and data -------------------------------------------------------
# Checkpoint the tokenizer, text encoder, VAE and UNet are loaded from.
pretrained_model_name_or_path = "duongna/stable-diffusion-v1-4-flax"
# Folder scanned for the `.png` training images.
train_data_dir = "./data/texture_inversion"
# "style" selects the style prompt templates; any other value selects the object templates.
learnable_property = "object"
# New token that will stand for the learned concept in prompts.
placeholder_token = "<bliss-symbol>"
# Existing token whose embedding is copied to initialize the placeholder's embedding.
initializer_token = "bliss"
# Seed for the JAX PRNG key.
seed = 42
# Training images are resized to resolution x resolution pixels.
resolution = 128
# Number of times every training image is repeated in the training dataset.
repeats = 100
# Whether to crop each image to a centered square before resizing.
center_crop = False

# --- Schedule -------------------------------------------------------------
# Batch size per device; the total batch size multiplies this by the local device count.
train_batch_size = 1
# Only consulted when max_train_steps is None; recomputed from max_train_steps before training.
num_train_epochs = 100
# Total number of optimization steps.
max_train_steps = 3000
# The learned embedding is checkpointed every save_steps steps.
save_steps = 500

# --- Optimizer ------------------------------------------------------------
# Flag for scaling the learning rate by the total train batch size.
# NOTE(review): confirm that the optimization cell consults this flag.
scale_lr = True
learning_rate = 5.0e-04
# NOTE(review): lr_warmup_steps and lr_scheduler are not referenced by the cells visible in
# this notebook; the schedule is built with optax.constant_schedule.
lr_warmup_steps = 500
lr_scheduler = "constant"
# Passed to optax.adamw as b1, b2, weight_decay and eps respectively.
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_weight_decay = 1e-2
adam_epsilon = 1e-08

# --- Output ---------------------------------------------------------------
# Directory receiving the intermediate embeddings and the final pipeline.
output_dir = "./output"

# Prompt templates for training. Each "{}" is filled with the placeholder token, and one
# template is drawn with random.choice for every training example.

# Templates used when the learned concept is an object (any learnable_property other than "style").
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

# Templates used when learnable_property == "style".
# NOTE(review): "a cropped painting in the style of {}" and "a close-up painting in the
# style of {}" each appear twice, so they are drawn twice as often as the other entries.
# Confirm whether that weighting is intended before deduplicating.
imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]


#### 2.3 set up helper functions

In [None]:
class TextualInversionDataset(Dataset):
    """Dataset of (tokenized prompt, image tensor) pairs for textual inversion.

    Every ``.png`` file directly inside ``data_root`` is a training image. Each
    item pairs one image with a prompt built from a randomly chosen template in
    which the placeholder token stands for the concept being learned.

    Args:
        data_root: Directory holding the training images.
        tokenizer: Callable tokenizer; must expose ``model_max_length`` and
            return an object with an ``input_ids`` attribute.
        learnable_property: ``"style"`` selects the style templates; any other
            value selects the object templates.
        size: Edge length, in pixels, of the square images returned.
        repeats: How many times each image is repeated when ``set == "train"``.
        interpolation: Key of ``PIL_INTERPOLATION`` naming the resampling filter.
        flip_p: Probability of the random horizontal flip.
        set: Split name; only ``"train"`` applies ``repeats``.
        placeholder_token: Token substituted into the prompt templates.
        center_crop: Whether to crop to a centered square before resizing.

    Raises:
        ValueError: If ``data_root`` contains no ``.png`` files.
        KeyError: If ``interpolation`` is not a key of ``PIL_INTERPOLATION``.
    """

    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        center_crop=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p

        # Sorted so the index -> image mapping does not depend on the arbitrary
        # order in which os.listdir returns directory entries.
        self.image_paths = sorted(
            os.path.join(self.data_root, file_name)
            for file_name in os.listdir(self.data_root)
            if file_name.endswith(".png")
        )
        print("self.image_paths: ", self.image_paths)

        self.num_images = len(self.image_paths)
        if self.num_images == 0:
            # Fail here with a clear message instead of producing a zero-length
            # dataset that breaks the step arithmetic later on.
            raise ValueError(f"No .png training images found in {self.data_root!r}.")

        # Only the training split is inflated by `repeats`.
        self._length = self.num_images * repeats if set == "train" else self.num_images

        self.interpolation = PIL_INTERPOLATION[interpolation]

        self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        """Return the number of items, including repeats for the training split."""
        return self._length

    def __getitem__(self, i):
        """Return a dict with ``input_ids`` and channel-first ``pixel_values`` for item ``i``."""
        example = {}

        # Indices wrap around the image list so repeated epochs reuse the same files.
        # The context manager releases the file handle as soon as the pixels are decoded.
        with Image.open(self.image_paths[i % self.num_images]) as source:
            img = np.array(source.convert("RGB"), dtype=np.uint8)

        text = random.choice(self.templates).format(self.placeholder_token)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        if self.center_crop:
            h, w = img.shape[0], img.shape[1]
            crop = min(h, w)
            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size), resample=self.interpolation)
        image = self.flip_transform(image)

        # Map uint8 pixel values from [0, 255] to float32 values in [-1, 1].
        image = np.array(image, dtype=np.uint8)
        image = (image / 127.5 - 1.0).astype(np.float32)

        # HWC -> CHW
        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example

def resize_token_embeddings(model, new_num_tokens, initializer_token_id, placeholder_token_id, rng):
    """Grow the text encoder's token-embedding matrix to ``new_num_tokens`` rows.

    Existing rows are copied unchanged, and the row for ``placeholder_token_id`` is
    set to a copy of the row for ``initializer_token_id``. Any other new rows
    keep their random normal initialization.

    Args:
        model: Text encoder exposing ``config.vocab_size`` and a ``params`` tree with
            the path ``text_model/embeddings/token_embedding/embedding``.
        new_num_tokens: Target vocabulary size, or ``None`` to leave the model as is.
        initializer_token_id: Row whose embedding seeds the placeholder row.
        placeholder_token_id: Row of the newly added placeholder token.
        rng: JAX PRNG key used to initialize the new rows.

    Returns:
        ``model`` itself, updated in place. The model is also returned when no
        resize is needed, so ``model = resize_token_embeddings(model, ...)`` never
        replaces the caller's model with ``None``.
    """
    if new_num_tokens is None or model.config.vocab_size == new_num_tokens:
        return model

    params = model.params
    old_embeddings = params["text_model"]["embeddings"]["token_embedding"]["embedding"]
    old_num_tokens, emb_dim = old_embeddings.shape

    initializer = jax.nn.initializers.normal()

    new_embeddings = initializer(rng, (new_num_tokens, emb_dim))
    # Keep every pre-existing embedding exactly as it was.
    new_embeddings = new_embeddings.at[:old_num_tokens].set(old_embeddings)
    # Start the placeholder from the initializer token's embedding.
    new_embeddings = new_embeddings.at[placeholder_token_id].set(new_embeddings[initializer_token_id])
    params["text_model"]["embeddings"]["token_embedding"]["embedding"] = new_embeddings

    model.params = params
    model.config.vocab_size = new_num_tokens
    return model


def get_params_to_save(params):
    """Fetch ``params`` from the devices, keeping only index 0 of every leaf."""

    def _first_entry(leaf):
        return leaf[0]

    first_entries = jax.tree_util.tree_map(_first_entry, params)
    return jax.device_get(first_entries)


### 3 Train

#### 3.1 load tokenizer

In [None]:
# Load the tokenizer and register the placeholder token as an additional token
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")

# Add the placeholder token to the tokenizer; the return value is the number of tokens added
num_added_tokens = tokenizer.add_tokens(placeholder_token)

print('num_added_tokens: ', num_added_tokens)

# Zero added tokens means the placeholder is already in the vocabulary. Training would then
# overwrite the embedding of an existing token instead of learning a new one, so stop here.
if num_added_tokens == 0:
    raise ValueError(
        f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
        " `placeholder_token` that is not already in the tokenizer."
    )

#### 3.2 Convert the initializer_token, placeholder_token to ids

In [None]:
# Convert the initializer_token, placeholder_token to ids
token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
print("token_ids: ", token_ids)

# The initializer must map to exactly one token: more than one cannot seed a single
# embedding row, and an empty result would otherwise fail below with an opaque IndexError.
if len(token_ids) != 1:
    raise ValueError("The initializer token must be a single token.")

initializer_token_id = token_ids[0]
placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)

print("initializer_token_id: ", initializer_token_id)
print("placeholder_token_id: ", placeholder_token_id)

#### 3.3 load models

In [None]:
# Load the three pretrained components from their subfolders of the same checkpoint.

# Text encoder: a single model object; its parameters are read from `text_encoder.params`.
text_encoder = FlaxCLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")

# VAE and UNet: each call returns the module together with its parameters.
vae, vae_params = FlaxAutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")

#### 3.4 Create sampling RNG and train dataset

In [None]:
# Seed the sampling PRNG and keep the first key of one split (the second is discarded)
rng = jax.random.PRNGKey(seed)
rng = jax.random.split(rng)[0]

# The tokenizer gained the placeholder token, so the embedding matrix has to grow to match
vocab_size_with_placeholder = len(tokenizer)
text_encoder = resize_token_embeddings(
    text_encoder, vocab_size_with_placeholder, initializer_token_id, placeholder_token_id, rng
)

# Snapshot of the embedding matrix taken before any training step
token_embedding_params = text_encoder.params["text_model"]["embeddings"]["token_embedding"]
original_token_embeds = token_embedding_params["embedding"]

# Training dataset built from the constants defined in step 2.2
dataset_options = {
    "data_root": train_data_dir,
    "tokenizer": tokenizer,
    "size": resolution,
    "placeholder_token": placeholder_token,
    "repeats": repeats,
    "learnable_property": learnable_property,
    "center_crop": center_crop,
    "set": "train",
}
train_dataset = TextualInversionDataset(**dataset_options)

def collate_fn(examples):
    """Stack the per-example tensors into one batch and convert each field to NumPy."""
    return {
        field: torch.stack([example[field] for example in examples]).numpy()
        for field in ("pixel_values", "input_ids")
    }

# One per-device batch for every local JAX device
total_train_batch_size = train_batch_size * jax.local_device_count()

# Shuffled loader; incomplete trailing batches are dropped
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=total_train_batch_size,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn,
)


#### 3.5 Optimization

In [None]:
# Scale the base learning rate by the total batch size only when `scale_lr` asks for it.
# NOTE: this rebinds `learning_rate`, so re-running this cell without first re-running the
# constants cell (2.2) applies the scaling again.
if scale_lr:
    learning_rate = learning_rate * total_train_batch_size

# The learning rate stays constant for the whole run
constant_scheduler = optax.constant_schedule(learning_rate)

optimizer = optax.adamw(
    learning_rate=constant_scheduler,
    b1=adam_beta1,
    b2=adam_beta2,
    eps=adam_epsilon,
    weight_decay=adam_weight_decay,
)

def create_mask(params, label_fn):
    """Build a label tree that mirrors ``params``.

    A key for which ``label_fn(key)`` is true is labelled ``"token_embedding"`` and
    its subtree is not descended into. Other keys are recursed into when their
    value is a ``dict`` and labelled ``"zero"`` otherwise.
    """

    def _label_tree(subtree):
        return {key: _label_for(subtree, key) for key in subtree}

    def _label_for(subtree, key):
        if label_fn(key):
            return "token_embedding"
        child = subtree[key]
        if isinstance(child, dict):
            return _label_tree(child)
        return "zero"

    return _label_tree(params)

def zero_grads():
    """Return an optax gradient transformation that replaces every update with zeros."""
    # from https://github.com/deepmind/optax/issues/159#issuecomment-896459491

    def _empty_state(_):
        # The transformation keeps no state.
        return ()

    def _zero_updates(updates, state, params=None):
        zeroed = jax.tree_util.tree_map(jnp.zeros_like, updates)
        return zeroed, ()

    return optax.GradientTransformation(_empty_state, _zero_updates)

# Only the token embedding layer is trained with AdamW; every other parameter is labelled
# "zero" and has its updates zeroed.
param_labels = create_mask(text_encoder.params, lambda s: s == "token_embedding")
tx = optax.multi_transform({"token_embedding": optimizer, "zero": zero_grads()}, param_labels)

# Training state: the text encoder's forward call, its parameters and the masked optimizer
state = train_state.TrainState.create(
    apply_fn=text_encoder.__call__,
    params=text_encoder.params,
    tx=tx,
)

# Noise scheduler used to add noise to the latents during training
noise_scheduler = FlaxDDPMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
)
noise_scheduler_state = noise_scheduler.create_state()


#### 3.6 Train

In [None]:
# Initialize our training: one PRNG key per local device
train_rngs = jax.random.split(rng, jax.local_device_count())

# Define gradient train step fn
def train_step(state, vae_params, unet_params, batch, train_rng):
    """Run one optimization step on a single device's shard of the batch.

    Mapped over devices with ``jax.pmap`` under the axis name ``"batch"``.
    Gradients are taken with respect to ``state.params`` only; ``vae_params``
    and ``unet_params`` are used as fixed inputs.

    Args:
        state: Train state holding the text-encoder parameters and optimizer state.
        vae_params: Parameters of the VAE used to encode images into latents.
        unet_params: Parameters of the UNet that predicts the noise.
        batch: Dict with ``"pixel_values"`` and ``"input_ids"`` arrays.
        train_rng: PRNG key for this step.

    Returns:
        ``(new_state, metrics, new_train_rng)``, where ``metrics`` holds the loss
        averaged across devices and ``new_train_rng`` is the key for the next step.

    Raises:
        ValueError: If the noise scheduler's prediction type is neither
            ``"epsilon"`` nor ``"v_prediction"``.
    """
    dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)
    # Derive one independent key per random draw. Reusing `sample_rng` both directly and as
    # the parent of further splits would violate JAX's "never reuse a key" rule.
    latent_rng, noise_rng, timestep_rng = jax.random.split(sample_rng, 3)

    def compute_loss(params):
        vae_outputs = vae.apply(
            {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode
        )
        latents = vae_outputs.latent_dist.sample(latent_rng)
        # (NHWC) -> (NCHW)
        latents = jnp.transpose(latents, (0, 3, 1, 2))
        latents = latents * vae.config.scaling_factor

        noise = jax.random.normal(noise_rng, latents.shape)
        bsz = latents.shape[0]
        # One random diffusion timestep per example
        timesteps = jax.random.randint(
            timestep_rng,
            (bsz,),
            0,
            noise_scheduler.config.num_train_timesteps,
        )
        noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
        encoder_hidden_states = state.apply_fn(
            batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True
        )[0]
        # Predict the noise residual and compute loss
        model_pred = unet.apply(
            {"params": unet_params}, noisy_latents, timesteps, encoder_hidden_states, train=False
        ).sample

        # Get the target for loss depending on the prediction type
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        # Mean squared error between the target and the prediction
        loss = (target - model_pred) ** 2
        loss = loss.mean()

        return loss

    grad_fn = jax.value_and_grad(compute_loss)
    loss, grad = grad_fn(state.params)
    # Average the gradients across devices before applying them
    grad = jax.lax.pmean(grad, "batch")
    new_state = state.apply_gradients(grads=grad)

    # Keep the token embeddings fixed except the newly added embeddings for the concept,
    # as we only want to optimize the concept embeddings
    token_embeds = original_token_embeds.at[placeholder_token_id].set(
        new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"][placeholder_token_id]
    )
    new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"] = token_embeds

    metrics = {"loss": loss}
    metrics = jax.lax.pmean(metrics, axis_name="batch")
    return new_state, metrics, new_train_rng

# Create the parallel version of the train step; the incoming train state buffers are donated
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))

# Replicate the train state on each device
state = jax_utils.replicate(state)
vae_params = jax_utils.replicate(vae_params)
unet_params = jax_utils.replicate(unet_params)

# Train!
num_update_steps_per_epoch = len(train_dataloader)
if num_update_steps_per_epoch == 0:
    # Without this check the epoch computation below fails with an opaque ZeroDivisionError.
    raise ValueError(
        "The training dataloader yields no batches; the dataset must contain at least"
        f" {total_train_batch_size} example(s) to fill one batch."
    )

# Scheduler and math around the number of training steps.
if max_train_steps is None:
    max_train_steps = num_train_epochs * num_update_steps_per_epoch

# The epoch count is always derived from the step budget
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

print("***** Running training *****")
print(f"  Num examples = {len(train_dataset)}")
print(f"  Num Epochs = {num_train_epochs}")
print(f"  Instantaneous batch size per device = {train_batch_size}")
print(f"  Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
print(f"  Total optimization steps = {max_train_steps}")

global_step = 0

# Intermediate embeddings are saved into output_dir during training, so it has to exist
# before the first checkpoint; otherwise the save fails at step `save_steps`.
os.makedirs(output_dir, exist_ok=True)

# Length of the per-epoch progress bar; it does not change between epochs
steps_per_epoch = len(train_dataset) // total_train_batch_size

epochs = tqdm(range(num_train_epochs), desc=f"Epoch ... (1/{num_train_epochs})", position=0)
for epoch in epochs:
    # ======================== Training ================================

    train_metrics = []

    train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
    # train
    for batch in train_dataloader:
        # Split the batch along its leading axis, one shard per local device
        batch = shard(batch)
        state, train_metric, train_rngs = p_train_step(state, vae_params, unet_params, batch, train_rngs)
        train_metrics.append(train_metric)

        train_step_progress_bar.update(1)
        global_step += 1

        # Stop as soon as the step budget is used up; the final embeddings are saved after training
        if global_step >= max_train_steps:
            break
        if global_step % save_steps == 0:
            # Periodic checkpoint containing only the learned placeholder embedding
            learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"][
                "embedding"
            ][placeholder_token_id]
            learned_embeds_dict = {placeholder_token: learned_embeds}
            jnp.save(
                os.path.join(output_dir, "learned_embeds-" + str(global_step) + ".npy"), learned_embeds_dict
            )

    # Report the loss of the last step of the epoch
    train_metric = jax_utils.unreplicate(train_metric)

    train_step_progress_bar.close()
    epochs.write(f"Epoch... ({epoch + 1}/{num_train_epochs} | Loss: {train_metric['loss']})")


#### 3.7 Create the pipeline using the trained modules and save it.

In [None]:
# Only the process with index 0 assembles and writes the pipeline
if jax.process_index() == 0:
    scheduler = FlaxPNDMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
    )
    safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker", from_pt=True
    )
    pipeline = FlaxStableDiffusionPipeline(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
    )

    # Fetch the trained text-encoder parameters from the devices once and reuse them for both
    # the pipeline export and the standalone embedding file.
    text_encoder_params = get_params_to_save(state.params)

    pipeline.save_pretrained(
        output_dir,
        params={
            "text_encoder": text_encoder_params,
            "vae": get_params_to_save(vae_params),
            "unet": get_params_to_save(unet_params),
            "safety_checker": safety_checker.params,
        },
    )

    # Also save the newly trained embeddings
    learned_embeds = text_encoder_params["text_model"]["embeddings"]["token_embedding"]["embedding"][
        placeholder_token_id
    ]
    learned_embeds_dict = {placeholder_token: learned_embeds}
    jnp.save(os.path.join(output_dir, "flax_learned_embeds.npy"), learned_embeds_dict)


### 4. Use the fine-tuned model

In [None]:
# Imported here so this cell also runs in a fresh session that skipped the training cells
import os

from diffusers import StableDiffusionPipeline

# Load the pipeline saved by step 3.7 from its Flax weights, without the safety checker
model_id = "./output"
pipe = StableDiffusionPipeline.from_pretrained(model_id, from_flax=True, safety_checker=None)

# Alternative prompt:
# prompt = "A <bliss-symbol> on a backpack"
prompt = "A new <bliss-symbol> for microcosm"

image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]

# Create the results folder first so that saving does not fail when it is missing
results_dir = "./data/texture_inversion_fineTuned_results"
os.makedirs(results_dir, exist_ok=True)

# Matching file name for the alternative prompt:
# image.save(os.path.join(results_dir, "bliss-on-backpack.png"))
image.save(os.path.join(results_dir, "bliss-for-microcosm.png"))
