# Stable Diffusion Dreambooth Token Checker

Inspired and adapted from scripts and ideas by 2kpr

https://www.reddit.com/r/StableDiffusion/comments/zc65l4/rare_tokens_for_dreambooth_training_stable/

The purpose of this notebook is to check what a given token / token+class pair might generate from the model you are planning to train on. It will also show a breakdown of how your token is tokenized.   

For each token in your list, a 3x3 grid of images is generated using that token as the prompt.  You can also include class words; each token is then additionally paired with the class word (e.g. `sks person`) and the pair is added to the list.

You can add a series of token words (comma separated). 

The output includes a breakdown of your token. For example, that long, convoluted token you came up with might actually be broken up into several sub-tokens, some of which may already have a very strong prior association in the model. 
 



In [None]:
#@title 1. Install Requirements

!pip install -q --upgrade diffusers[torch]
!pip install -q xformers==0.0.16
!pip install -q transformers
!pip install -q triton


# ===========================================
import os
import shutil
from os import path
from os.path import exists
import torch
from diffusers import StableDiffusionPipeline, AutoencoderKL, DDIMScheduler
import xformers
from IPython.display import clear_output 

# font
if exists('inconsolata.regular.ttf')==False:
  #!wget https://github.com/larsenwork/Gidole/raw/master/Resources/GidoleFont/Gidole-Regular.ttf
  !wget 'https://www.1001fonts.com/download/font/inconsolata.regular.ttf'

# create output dir
OUTPUT_DIR = "output"
if path.exists(OUTPUT_DIR)==False:
    os.mkdir(OUTPUT_DIR)

#MODEL_PATH = "models/stable-diffusion-v1-5" # local path
#VAE_PATH = "models/sd-vae-ft-mse" # local path

MODEL_PATH = "runwayml/stable-diffusion-v1-5"
#MODEL_PATH = "stabilityai/stable-diffusion-2-1"
VAE_PATH = "stabilityai/sd-vae-ft-mse"

if torch.cuda.is_available():
    torch.cuda.empty_cache()

pipe = StableDiffusionPipeline.from_pretrained(f'{MODEL_PATH}', vae = AutoencoderKL.from_pretrained(f'{VAE_PATH}', torch_dtype = torch.float16), revision="fp16", torch_dtype = torch.float16, safety_checker = None)

pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

pipe = pipe.to("cuda")

pipe.enable_xformers_memory_efficient_attention()

# =========================================================================================
# Copy the vocab.json from the cache to current dir
# =========================================================================================
# tokenizer/vocab.json files are normally saved in cache e.g.
# /root/.cache/huggingface/hub/models--runwayml--stable-diffusion-v1-5/snapshots/xxxxxxxxxxxxxx/tokenizer/vocab.json

def find_tokenizer_vocab_files(init_path='/root/.cache/', name='vocab.json'):
  """Recursively collect the paths of every file called `name` below `init_path`.

  Args:
    init_path: Directory to search. Defaults to '/root/.cache/', the location
      this notebook expects the Hugging Face cache to live in.
    name: Exact file name to look for. Defaults to 'vocab.json'.

  Returns:
    A list of path strings (directory joined with `name`) in `os.walk` order.
    The list is empty when nothing matches or `init_path` does not exist.
  """
  return [
      os.path.join(root, name)
      for root, _dirs, files in os.walk(init_path)
      if name in files
  ]

# Locate every cached vocab.json and mirror it into ./dictionary/ so that
# tokenize() can read it through Vocab_File_Path.
result = find_tokenizer_vocab_files()

# Model id split on '/'; the last element is the repo name without the user
# or organisation prefix. Using [-1] also copes with ids that contain no '/'
# or more than one.
HF_repo_name = MODEL_PATH.split('/')

for cached_vocab_path in result:
  print('-> ' + cached_vocab_path)

  # split the cached dir path
  dirs_list = cached_vocab_path.split('/')

  # Build "dictionary/<repo>/<parent folder>/<file name>". The last two path
  # components are taken relative to the end of the path so the result does
  # not depend on how deep the cache directory happens to be.
  copy_to_dir = 'dictionary/' + HF_repo_name[-1] + '/' + dirs_list[-2] + '/' + dirs_list[-1]

  # create target dir if not exist
  os.makedirs(os.path.dirname(copy_to_dir), exist_ok=True)
  shutil.copy(cached_vocab_path, copy_to_dir)

  # NOTE(review): when several vocab.json files are cached, the last one found
  # wins - confirm that it belongs to MODEL_PATH.
  Vocab_File_Path = copy_to_dir

if not result:
  # Without a vocab file Vocab_File_Path stays undefined and tokenize() cannot run.
  print("Warning: no vocab.json found in the cache; Vocab_File_Path was not set.")

print("Finished installing requirements.") 

# ======================================================================================
# DEFINE FUNCTIONS 
# ======================================================================================

#@title 2. Define Functions
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
from transformers import CLIPTokenizer, CLIPModel
import json

# Not referenced by any code visible in this notebook.
split_positions = []
# Edge length in pixels used to size the grid canvas and to position each
# image inside it.
image_width_height = 512
# Number of image rows and columns in one grid.
grid_rows = 3
grid_cols = 3
# Total pixel width of a grid; tokenize() wraps caption lines against it.
grid_width = image_width_height * grid_cols
# Font used for measuring and drawing the caption text; loaded from a .ttf
# file expected in the working directory.
myFont = ImageFont.truetype("inconsolata.regular.ttf", size=35)


# ===========================================================================
def tokenize(token):
  """Break a prompt into its tokenizer sub-tokens and format them as caption lines.

  The prompt is tokenized with the CLIP tokenizer of MODEL_PATH and every id is
  translated back to its vocabulary string via the vocab file at
  Vocab_File_Path. The begin/end-of-text markers are left out.

  Args:
    token: The prompt to analyse; it is formatted into a string before use.

  Returns:
    A list of strings. The first entry is the header "Token: <token> : ".
    The following entries are lines of "[sub-token]" items separated by
    spaces, wrapped so that a line's rendered width (in myFont) plus a fixed
    margin stays below grid_width where possible.
  """
  # id -> token string lookup (vocab.json maps token string -> id)
  with open(Vocab_File_Path, "r", encoding='utf-8') as f:
    vocab = {v: k for k, v in json.load(f).items()}

  tokenizer = CLIPTokenizer.from_pretrained(MODEL_PATH, subfolder="tokenizer",revision="fp16")
  input_ids = tokenizer(f'{token}', padding=True)["input_ids"]

  # <|startoftext|> (49406) and <|endoftext|> (49407) for the CLIP vocabulary;
  # asking the tokenizer avoids hard-coding the numbers.
  special_ids = {tokenizer.bos_token_id, tokenizer.eos_token_id}

  margin = 256  # horizontal space (px) kept free when deciding to wrap
  return_list = [f'Token: {token}' + " : "]

  # Positions before this index get a separating space appended; the last
  # real sub-token (just before the end marker) does not.
  last_separated = len(input_ids) - 2

  line = ''
  for position, token_id in enumerate(input_ids):
    if token_id in special_ids:
      continue

    line += '[' + vocab[token_id] + ']'

    # Start a new caption line once the current one is too wide for the grid.
    if (get_text_size(line, myFont)[0] + margin) >= grid_width:
      return_list.append(line.strip())
      line = ''

    if position < last_separated:
      line += ' '

  # Add the last (or only) line, but not an empty one left over after a wrap
  # that happened exactly on the final sub-token.
  remainder = line.strip()
  if remainder:
    return_list.append(remainder)

  print(">>>")
  print(return_list )

  return return_list

# ===========================================================================
def generate_grid():
  """Render one captioned image grid per prompt in `token_list` and save it as a JPEG.

  For every prompt the function
    1. obtains the caption lines from tokenize(),
    2. generates grid_rows * grid_cols images with `pipe` (20 inference steps each)
       and pastes them below a title bar,
    3. draws the caption lines into the title bar, and
    4. saves the result as OUTPUT_DIR/<prompt>.jpg.

  Reads the module-level names token_list, pipe, myFont, OUTPUT_DIR and the
  grid constants; returns nothing.

  NOTE(review): images are pasted unscaled, so this assumes the pipeline
  produces image_width_height x image_width_height images - confirm for
  models other than the default.
  """
  print("start")

  titlebar = 128  # height in pixels reserved above the images for the caption
  tiles_per_grid = grid_rows * grid_cols
  # Path separators and other characters that are not safe in a file name.
  unsafe_chars = '\\/:*?"<>|'

  for token in token_list:
    token_data_list = tokenize(token)
    print("Token: " + token)

    # Blank canvas: the image grid plus the title bar on top.
    grid = Image.new('RGB', size=(grid_cols*image_width_height, (grid_rows*image_width_height)+titlebar))
    with torch.autocast("cuda"), torch.inference_mode():
      for i in range(tiles_per_grid):
        image = pipe(token, num_inference_steps=20, num_images_per_prompt=1).images[0]
        row, column = divmod(i, grid_cols)
        grid.paste(image, box=(column*image_width_height, (row*image_width_height)+titlebar))

    draw = ImageDraw.Draw(grid)

    # Stack the caption lines from the top-left corner; the header line
    # determines the line height used for all lines.
    text_height = get_text_size(token_data_list[0], myFont)[1]
    y = 0
    for line in token_data_list:
      draw.text((0, y), line, font=myFont, fill=(255, 255, 255))
      y += text_height

    # The prompt doubles as the file name, so replace characters that would
    # be read as path components and give blank prompts a usable name.
    filename = "".join('_' if ch in unsafe_chars else ch for ch in token)
    if not filename.strip():
      filename = 'empty'

    # Quality is not a priority here: compressed JPEG keeps files small and
    # quick to load.
    grid.save(OUTPUT_DIR + "/" + filename + ".jpg", optimize=True, quality=60)

# ===========================================================================
def get_text_size(text_string, font):
  """Measure the size of `text_string` when rendered with `font`.

  Args:
    text_string: The text to measure.
    font: A font object providing getmetrics() and getmask(), e.g. the
      result of ImageFont.truetype().

  Returns:
    A (width, height) tuple: the right edge of the rendered mask's bounding
    box, and its bottom edge plus the font's descent. When the mask has no
    bounding box (nothing was drawn, e.g. for empty text) the result is
    (0, descent).
  """
  # https://stackoverflow.com/a/46220683/9263761
  _ascent, descent = font.getmetrics()

  # Render the mask once and reuse its bounding box for both dimensions.
  bbox = font.getmask(text_string).getbbox()
  if bbox is None:
    return (0, descent)

  return (bbox[2], bbox[3] + descent)


In [None]:
#@title 2. Token / Class Words
#@markdown add your token words you would like to test (comma separated)
TOKENS = "sks, supercalifragilisticexpialidocious12344567" #@param {type:"string"}


# Split on commas when the input contains any, otherwise on single spaces.
# Empty entries (e.g. from consecutive separators) are kept as empty prompts.
delimeter = ',' if ',' in TOKENS else ' '

# Trim the whitespace around every entry.
token_list = list(map(str.strip, TOKENS.split(delimeter)))

#@markdown ___ 

#@markdown (Optional) add class to the list?

class_person = True #@param {type:"boolean"}
class_man = False #@param {type:"boolean"}
class_woman = False #@param {type:"boolean"}
class_artstyle = False #@param {type:"boolean"}
class_other = "" #@param {type:"string"}

# One "<token> <class>" list per class option; a list stays empty when its
# option is switched off (or, for the free-text class, left blank).
token_and_person_list = [f'{t} person' for t in token_list] if class_person else []
token_and_man_list = [f'{t} man' for t in token_list] if class_man else []
token_and_woman_list = [f'{t} woman' for t in token_list] if class_woman else []
token_and_artstyle_list = [f'{t} artstyle' for t in token_list] if class_artstyle else []
token_and_other_list = [f'{t} {class_other}' for t in token_list] if class_other else []

# Append the token+class prompts after the plain tokens, class by class.
for class_variants in (
    token_and_person_list,
    token_and_man_list,
    token_and_woman_list,
    token_and_artstyle_list,
    token_and_other_list,
):
    token_list.extend(class_variants)

# Show the final list of prompts.
for t in token_list:
    print(t)
#@markdown For every token in your list, the class word will be appended to that token and added to the list. \
#@markdown Ex. If you have `'wow, owo, sks'` in your list, and select `'person'` class, then the list would become:
#@markdown * wow
#@markdown * owo
#@markdown * sks
#@markdown * wow person
#@markdown * owo person
#@markdown * sks person


In [None]:
#@title 3. Generate!
# Render and save one image grid for every entry in token_list.
generate_grid()

print("grids saved in 'output' directory")

In [None]:
#@title Delete images from dir /content/output
import os
import shutil

# Named images_dir rather than `dir` so the built-in dir() is not shadowed
# for the rest of the notebook session.
images_dir = "/content/output"

#!rm -rf "$images_dir"

if os.path.isdir(images_dir):
    # Empty the directory but keep the directory itself.
    for filename in os.listdir(images_dir):
        file_path = os.path.join(images_dir, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except OSError as e:
            # Best effort: report the entry that could not be removed and carry on.
            print('Failed to delete %s. Reason: %s' % (file_path, e))
else:
    # Nothing has been generated yet (or the notebook runs outside /content).
    print('Nothing to delete: %s does not exist.' % images_dir)

print("you may need to refresh the file manager view if the files is still visible after deleting.")
