## Intro
Run FLUX text-to-image generation on Apple silicon
* black-forest-labs/FLUX.1-schnell https://huggingface.co/black-forest-labs/FLUX.1-schnell

In [None]:
import torch
import os
import applyllm as apl

# Print the installed applyllm version so a run of this notebook can be
# matched to the library release it was executed with.
print(apl.__version__)

In [None]:
# Report whether PyTorch's MPS backend is usable on this machine.
# `mps_device` is only bound when the backend is available.
if torch.backends.mps.is_available():
    print("MPS is available")
    mps_device = torch.device("mps")
    print(mps_device)
else:
    print("MPS is not available")

In [None]:
from applyllm.accelerators import (
    DirectorySetting,
    TokenHelper,
)
from applyllm.utils import time_func

# Directory layouts per runtime environment.
# NOTE(review): the "mac_local" entry hard-codes one user's home directory;
# adjust it before running on another machine.
dir_mode_map = {
    "kf_notebook": DirectorySetting(),
    "mac_local": DirectorySetting(
        home_dir="/Users/yingding",
        transformers_cache_home="MODELS",
        huggingface_token_file="MODELS/.huggingface_token",
    ),
}

# Short alias used in this notebook -> Hugging Face Hub repository id.
model_map = {
    "diffusion-v1.5": "runwayml/stable-diffusion-v1-5",
    "stability-sd3-medium": "stabilityai/stable-diffusion-3-medium-diffusers",
    "free-flux-1-12B": "black-forest-labs/FLUX.1-schnell",
}

# Defaults applied by the rest of the notebook.
default_dir_mode = "mac_local"
default_model_type = "free-flux-1-12B"

dir_setting = dir_mode_map[default_dir_mode]

# Environment variables are set before any model is loaded.
# XDG_CACHE_HOME is pointed at the cache location reported by the chosen
# directory setting; it is echoed below for a quick visual check.
os.environ["WORLD_SIZE"] = "1"
os.environ["XDG_CACHE_HOME"] = dir_setting.get_cache_home()

print(os.environ["XDG_CACHE_HOME"])

In [None]:
import diffusers

# Log the diffusers and torch versions (one per line) for reproducibility.
print(diffusers.__version__, torch.__version__, sep="\n")

In [None]:
# Choose which alias from `model_map` to load. The alternatives are kept
# as comments so the model can be switched quickly.
# model_type = "diffusion-v1.5"
# model_type = "stability-sd3-medium"
model_type = "free-flux-1-12B"

# Resolve the alias to a Hub repository id. For an unknown alias, fall back
# to the repository id of the default alias; falling back to the alias
# string itself would hand an invalid repo id to `from_pretrained`.
model_name = model_map.get(model_type, model_map[default_model_type])
print(model_name)

In [None]:
# Not needed yet:
# from applyllm.pipelines import ModelCatalog, KwargsBuilder

# Token-related keyword arguments for the selected model type, as produced
# by applyllm's TokenHelper.
th = TokenHelper(
    dir_setting=dir_setting,
    prefix_list=["llama", "stability"],
)
token_kwargs = th.gen_token_kwargs(model_type=model_type)

# Precision used for the pipeline weights. float16 is active; the other
# candidates are torch.bfloat16 and torch.float32.
# Author's note: bfloat16 is not supported on the MPS backend, and float16
# only works on a GPU accelerator.
data_type = torch.float16
device_map = "mps"
# Autocast was reported as not working on MPS with transformers 4.38.2:
# https://github.com/huggingface/transformers/issues/29431

# Target image size in pixels.
height = 512
width = 512

model_kwargs = dict(
    torch_dtype=data_type,
    pretrained_model_name_or_path=model_name,
    use_fast=True,  # use fast tokenizers
    height=height,
    width=width,
    num_images_per_prompt=1,
    num_inference_steps=10,  # 28 for 512x512, 56 for 1024x1024
    guidance_scale=7.0,
)
print(f"model_kwargs: {model_kwargs}")

In [None]:
from diffusers import FluxPipeline

# `model_kwargs` mixes loader options with sampling options. Sampling options
# only take effect when they are passed to the pipeline *call*, so they are
# separated here and forwarded by `img_gen`; everything else goes to the loader.
_GENERATION_KEYS = (
    "height",
    "width",
    "num_images_per_prompt",
    "num_inference_steps",
    "guidance_scale",
)
generation_kwargs = {
    key: model_kwargs[key] for key in _GENERATION_KEYS if key in model_kwargs
}
load_kwargs = {
    key: value
    for key, value in model_kwargs.items()
    if key not in generation_kwargs
}

pipe = FluxPipeline.from_pretrained(**load_kwargs, **token_kwargs)
# Optional memory savers, left disabled:
# pipe.enable_model_cpu_offload() #save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power
# pipe.vae.enable_tiling()
# pipe = pipe.to(torch.float16)
pipe = pipe.to(device_map)

prompt = "a photo of an astronaut riding a horse on mars"


@time_func
def img_gen(prompt: str):
    """Generate an image for ``prompt`` and return the first result.

    The sampling settings collected in ``generation_kwargs`` (image size,
    number of steps, guidance scale, images per prompt) are applied to
    every call.

    Args:
        prompt: Text description of the image to generate.

    Returns:
        The first entry of the pipeline output's ``images`` list.
    """
    return pipe(prompt, **generation_kwargs).images[0]


image = img_gen(prompt)

In [None]:
# Output folder for generated images, relative to the current working
# directory. It is created if missing and left untouched if it already exists.
img_dir = "imgs/flux1"
os.makedirs(name=img_dir, exist_ok=True)

In [None]:
img_file_name = "astronaut_on_mars.png"


def save_image(image, img_file_name, overwrite=False, output_dir=None):
    """Save ``image`` into the output directory and return the file path.

    Args:
        image: Object exposing ``save(path)``.
        img_file_name: File name (not a full path) inside the output directory.
        overwrite: When False, an existing file with the same name is kept
            and nothing is written.
        output_dir: Directory to write into. Defaults to the notebook-level
            ``img_dir`` when omitted.

    Returns:
        The path at which the image is stored, whether or not it was written
        by this call.
    """
    target_dir = img_dir if output_dir is None else output_dir
    # Create the directory on demand so the function does not depend on a
    # separate directory-creating cell having been run first.
    os.makedirs(target_dir, exist_ok=True)
    img_file_path = os.path.join(target_dir, img_file_name)
    if overwrite or not os.path.exists(img_file_path):
        image.save(img_file_path)
    return img_file_path

# Overwrite any earlier render so the file on disk always holds the image
# generated in this run, consistent with the later prompts in this notebook
# that also pass overwrite=True.
img_file_path = save_image(image, img_file_name, overwrite=True)

In [None]:
# Reload the saved image from disk (path held in img_file_path) and render it
# inline in the notebook cell output.
from PIL import Image
img = Image.open(img_file_path)
# Show inline. `display` is not imported in this file; presumably it is the
# builtin that Jupyter/IPython injects into the notebook namespace.
display(img)
# Alternative: open the image in a separate viewer window.
# img.show()


In [None]:
prompt = "a dog with a red hat riding a skateboard on the street"

image = img_gen(prompt)

img_file_name = "dog_with_red_hat_on_skateboard.png"
# Overwrite any earlier render so the file on disk always holds the image
# generated in this run, consistent with the later prompts in this notebook
# that also pass overwrite=True.
img_file_path = save_image(image, img_file_name, overwrite=True)

In [None]:
# Reload the file at img_file_path from disk and render it inline in the
# notebook cell output.
img = Image.open(img_file_path)
# show inline
display(img)

In [None]:
# Radish-on-the-beach sample: generate, then store under a fixed file name,
# replacing any previous render.
prompt = (
    "a radish vegetable with sunglasses on top of it, "
    "sits on a beach chair with sand, ocean, palm tree in the background"
)
img_file_name = "radish_with_sunglasses_on_beach.png"

image = img_gen(prompt)
img_file_path = save_image(
    image=image,
    img_file_name=img_file_name,
    overwrite=True,
)

In [None]:
# Reload the file at img_file_path from disk and render it inline in the
# notebook cell output.
img = Image.open(img_file_path)
# show inline
display(img)

In [None]:
# Earlier, shorter wording of this prompt, kept for comparison:
# prompt = "a strawberry unicorn in impressionist style dancing in a field with flowers and butterflies flying around"
prompt = (
    "a unicorn with a strawberry at the top of its horn in impressionist style "
    "dancing in a field with flowers and butterflies flying around"
)
img_file_name = "strawberry_dance.png"

# Generate, then store under a fixed file name, replacing any previous render.
image = img_gen(prompt)
img_file_path = save_image(
    image=image,
    img_file_name=img_file_name,
    overwrite=True,
)

In [None]:
# Reload the file at img_file_path from disk and render it inline in the
# notebook cell output.
img = Image.open(img_file_path)
# show inline
display(img)

In [None]:
# Dog-with-sunglasses sample: generate, then store under a fixed file name,
# replacing any previous render.
prompt = "dark dog wearing a black sunglasses like a cool dude"
img_file_name = "dog_sunglasses.png"

image = img_gen(prompt)
img_file_path = save_image(
    image=image,
    img_file_name=img_file_name,
    overwrite=True,
)

In [None]:
# Reload the file at img_file_path from disk and render it inline in the
# notebook cell output.
img = Image.open(img_file_path)
# show inline
display(img)

In [None]:
# Cat-with-sign sample (the prompt asks for legible text in the image):
# generate, then store under a fixed file name, replacing any previous render.
prompt = "A cat holding a sign that says happy vacations"
img_file_name = "cat_sign_happy_vacations.png"

image = img_gen(prompt)
img_file_path = save_image(
    image=image,
    img_file_name=img_file_name,
    overwrite=True,
)

In [None]:
# Reload the file at img_file_path from disk and render it inline in the
# notebook cell output.
img = Image.open(img_file_path)
# show inline
display(img)