In [None]:
import os
import gc

import torch
import base64
from io import BytesIO

from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from diffusers import DiffusionPipeline

In [None]:
# prompt setup
# Every tunable value for the run lives in this cell; later cells only read them.

# Which model to load and what to ask it for.
model_id = "black-forest-labs/FLUX.1-dev"
prompt = "a cat sitting on a pole"

# Output image size (square), forwarded to the pipeline call.
width = height = 64

# Sampling settings, forwarded to the pipeline call unchanged.
num_inference_steps = 1
guidance_scale = 2

# Image encoding used when the result is serialised.
# NOTE: this name shadows the built-in `format()` for the rest of the notebook.
format = "JPEG"

In [None]:
# check gpu
# NOTE(review): this TensorFlow switch presumably only takes effect if it is
# set before TensorFlow is imported -- confirm it is early enough here.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# check gpu(s)
# Build a per-GPU memory budget: free memory in whole GiB minus a safety
# margin, queried for each device individually and never below zero.
reserved_gb = 2
n_gpus = torch.cuda.device_count()
max_memory = {}
for gpu_index in range(n_gpus):
    try:
        free_bytes = torch.cuda.mem_get_info(gpu_index)[0]
    except (AssertionError, RuntimeError):
        # CUDA could not be initialised for this device (CPU-only build,
        # missing driver, ...): keep the original "0" placeholder.
        max_memory[gpu_index] = 0
    else:
        free_gb = int(free_bytes / 1024 ** 3) - reserved_gb
        max_memory[gpu_index] = f"{max(free_gb, 0)}GB"
print('max memory:', max_memory)

# Start from a clean slate before the model is loaded.
gc.collect()
torch.cuda.empty_cache()


In [None]:
# pipe setup

# Pick the weight dtype from the GPU's major compute capability:
#   >= 8 -> bfloat16, 7 -> float16, anything lower (or no GPU) -> library default.
if torch.cuda.is_available():
    compute_capability = torch.cuda.get_device_properties(0).major
else:
    compute_capability = 0  # no CUDA device: fall through to the default dtype

if compute_capability >= 8:
    torch_dtype = torch.bfloat16
elif compute_capability >= 7:
    torch_dtype = torch.float16
else:
    torch_dtype = None  # auto setup for < 7

try:
    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
except Exception as load_error:
    # model_id could not be loaded as a full pipeline: report why, then load
    # the base model and apply model_id on top of it as LoRA weights.
    print(f"FluxPipeline load failed ({load_error!r}); retrying with model_id as LoRA weights")
    base_model = "black-forest-labs/FLUX.1-dev"
    pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch_dtype)
    pipe.load_lora_weights(model_id)


# NOTE: quantize / freeze / qfloat8 are not imported in the cells shown here;
# add the matching import before enabling the block below.
# # for low GPU RAM, quantize from 16b to 8b
# quantize(pipe.transformer, weights=qfloat8)
# freeze(pipe.transformer)
# quantize(pipe.text_encoder_2, weights=qfloat8)
# freeze(pipe.text_encoder_2)

# # for even lower GPU RAM
# pipe.vae.enable_tiling()
# pipe.vae.enable_slicing()

# Sequential offload needs an accelerator to offload from; skip it on
# CPU-only machines so the CPU fallback stays usable.
if torch.cuda.is_available():
    pipe.enable_sequential_cpu_offload()



In [None]:
# generate image

# Fixed seed so a Restart & Run All reproduces the same image;
# change the value to get a different sample.
seed = 0
generator = torch.Generator(device=device).manual_seed(seed)

image = pipe(
    prompt=prompt,
    width=width,
    height=height,
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
    generator=generator,
).images[0]


# Release the pipeline: drop the reference, collect anything still held by
# reference cycles, then hand cached GPU memory back.
# NOTE: re-running this cell requires re-running the pipe setup cell first.
pipe = None
gc.collect()
torch.cuda.empty_cache()

In [None]:
# format the result

# refs (kept for reference only, never executed)
# predictions.append({
#                     'result': [{
#                         'from_name': "generated_text",
#                         'to_name': "text_output", #audio
#                         'type': 'textarea',
#                         'value': {
#                             'data': base64_output,
#                             "url": generated_url,
#                         }
#                     }],
#                     'model_version': ""
#                 })
#                 print(predictions)


# Encode the image into an in-memory buffer using the configured format.
buffered = BytesIO()
image.save(buffered, format=format)

# Base64 text form of the encoded bytes so the image can travel inside a dict.
encoded_bytes = base64.b64encode(buffered.getvalue())
img_base64 = encoded_bytes.decode("utf-8")

# Inner payload: the encoding used plus the image data itself.
payload = {
    "format": format,
    "image": img_base64,
    # "image_url": //TODO: store in s3 bucket
}
result = {"model_version": model_id, "result": payload}

# Final envelope: status message wrapped around the result.
json_response = {
    "message": "predict completed successfully",
    "result": result,
}

print(json_response)