<a href="https://colab.research.google.com/github/snrism/sdxl-tuner/blob/main/Fine_tune_Images_SDXL_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SDXL Fine-tuning

Stability AI open-sourced [SDXL](https://replicate.com/blog/run-sdxl-with-an-api), a text-to-image diffusion model that you can fine-tune to learn patterns, such as a subject or a style, from your own images.

You can fine-tune SDXL on [Replicate](https://replicate.com) either [from the web](https://replicate.com/stability-ai/sdxl) or through the [API](https://replicate.com/blog/run-sdxl-with-an-api). This Colab notebook uses the API.

In [16]:
!pip install replicate
import os
import replicate
from google.colab import output
output.clear()  # hide the pip install log

Authenticate by setting your token in an environment variable:

In [23]:
# get your token from https://replicate.com/account
from getpass import getpass

REPLICATE_API_TOKEN = getpass()
os.environ["REPLICATE_API_TOKEN"] = REPLICATE_API_TOKEN

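If you re-run the notebook, the `getpass` prompt asks for the token every time. A small helper can reuse the token once it is in the environment. This is a sketch; `get_replicate_token` is a hypothetical name, not part of the `replicate` package. It relies on the `replicate` client reading `REPLICATE_API_TOKEN` from the environment.

```python
import os
from getpass import getpass


def get_replicate_token(env_var: str = "REPLICATE_API_TOKEN") -> str:
    """Return the API token from `env_var`, prompting only if it is not set yet.

    Hypothetical helper: on re-runs the prompt is skipped because the
    environment variable already holds a token.
    """
    token = os.environ.get(env_var, "").strip()
    if not token:
        token = getpass("Replicate API token: ").strip()
        os.environ[env_var] = token  # the replicate client reads this variable
    return token
```

Usage: `REPLICATE_API_TOKEN = get_replicate_token()`.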


## Prepare your training images

- Images can be of yourself or of a particular style, such as illustrations.
- Images can be in JPEG or PNG format.
- Dimensions and size don't matter.
- Filenames don't matter.
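Because the trainer expects JPEG or PNG files, a quick check before zipping can catch stray files such as `.DS_Store` or notes. This is a minimal sketch that looks only at file extensions (it does not inspect file contents); `find_unsupported_files` is a hypothetical helper.

```python
from pathlib import Path
from typing import List

# Extensions treated as JPEG or PNG (assumption: files are named conventionally)
SUPPORTED_SUFFIXES = {".jpg", ".jpeg", ".png"}


def find_unsupported_files(folder: str) -> List[str]:
    """List files under `folder` whose extension is not JPEG or PNG."""
    root = Path(folder)
    return sorted(
        str(path.relative_to(root))
        for path in root.rglob("*")
        if path.is_file() and path.suffix.lower() not in SUPPORTED_SUFFIXES
    )
```

For example, `find_unsupported_files("data")` returns an empty list when every file in `data` looks like a JPEG or PNG.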

Put your images in a folder (here, `data`) and zip it:

```console
zip -r data.zip data
```
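If you prefer to stay in Python (for example, when the images are already in the Colab file system), the standard library's `shutil.make_archive` can produce an equivalent archive. This is a sketch; `zip_image_folder` is a hypothetical helper, and it stores entries as `data/<file>`, as `zip -r data.zip data` does.

```python
import shutil
from pathlib import Path


def zip_image_folder(folder: str, archive_base: str) -> str:
    """Zip `folder` and its contents into `<archive_base>.zip`; return the archive path."""
    folder_path = Path(folder).resolve()
    if not folder_path.is_dir():
        raise FileNotFoundError(f"No such folder: {folder_path}")
    # root_dir is the parent and base_dir is the folder's name, so entries
    # inside the zip are stored as "<folder name>/<file>".
    return shutil.make_archive(
        archive_base,
        "zip",
        root_dir=str(folder_path.parent),
        base_dir=folder_path.name,
    )
```

Usage: `zip_image_folder("data", "data")` writes `data.zip` to the current directory.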

Upload the zip file to a publicly accessible URL, such as a GCS bucket, an S3 bucket, or a GitHub Pages site. You will pass this URL as `input_images` when you start training.
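Training cannot start if Replicate is unable to download the zip, so it can help to confirm that the URL answers an anonymous request. This is a standard-library sketch; `check_public_url` is a hypothetical helper. Some hosts (for example, pre-signed URLs created for GET only) reject HEAD requests, so treat `False` as a reason to double-check rather than as proof that the file is private.

```python
import urllib.request


def check_public_url(url: str, timeout: float = 10.0) -> bool:
    """Return True if `url` answers an unauthenticated HEAD request with a 2xx status."""
    try:
        # Request() raises ValueError for malformed URLs, so build it inside the try
        request = urllib.request.Request(url, method="HEAD")
        with urllib.request.urlopen(request, timeout=timeout) as response:
            return 200 <= response.status < 300
    except (OSError, ValueError):
        # OSError covers URLError, HTTPError (4xx/5xx), and timeouts
        return False
```

Usage: `check_public_url("https://example.com/data.zip")`, substituting the URL you uploaded to.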


## Create a model

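Create a destination model to hold your fine-tuned weights. Replace the owner `snrism` with your own Replicate username (here and in `destination` when you start training), set the model name, and feel free to keep the default visibility and hardware values.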

You can also create a model via [replicate.com/create](https://replicate.com/create).

In [None]:
import replicate

client = replicate.Client(api_token=REPLICATE_API_TOKEN)
client.models.create(
    owner="snrism",
    name="YOUR_MODEL_NAME",
    visibility="private",
    hardware="gpu-a40-large"
)

Use the following to fetch the base SDXL model and retrieve its latest version.

Look for the latest version's ID: the training call below needs it in the form `owner/model:version_id`.

In [None]:
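model = client.models.get("stability-ai/sdxl")
print(model)
# ID of the latest version, used as "stability-ai/sdxl:<version_id>" when training
print(model.latest_version.id)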
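Rather than pasting the ID by hand, you can assemble the `owner/model:version_id` string and catch obvious copy-paste mistakes. This is a sketch; `build_version_ref` is a hypothetical helper, and its checks (one `/` in the model name, no whitespace, `:` or `/` in the ID) are assumptions about the format, not rules documented by Replicate.

```python
def build_version_ref(model_name: str, version_id: str) -> str:
    """Combine "owner/name" and a version ID into "owner/name:version_id"."""
    owner, sep, name = model_name.partition("/")
    if not (owner and sep and name) or "/" in name:
        raise ValueError(f"expected 'owner/name', got {model_name!r}")
    # Reject empty IDs or IDs that look like a full reference or contain whitespace
    if not version_id or ":" in version_id or "/" in version_id or any(
        ch.isspace() for ch in version_id
    ):
        raise ValueError(f"unexpected version id: {version_id!r}")
    return f"{model_name}:{version_id}"
```

Usage: `build_version_ref("stability-ai/sdxl", model.latest_version.id)`.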

## Start the training

In [38]:
# NOTE: The cell below fine-tunes Stability AI's SDXL, which does a good job of capturing a style and reflecting it in the output.
# FLUX, from Black Forest Labs (https://blackforestlabs.ai/), is a newer image generation model that produces high-quality images.
# To try it, replace the version and input below with those of the FLUX LoRA trainer:
# version="ostris/flux-dev-lora-trainer:885394e6a31c6f349dd4f9e6e7ffbabd8d9840ab2559ab78aed6b2451ab2cfef",
# input={
#     "steps": 1000,
#     # Higher lora_rank: a more adaptable model that can capture more complex nuances in your data,
#     # but with higher memory and compute requirements and a risk of overfitting if set too high.
#     "lora_rank": 16,
#     "optimizer": "adamw8bit",
#     "batch_size": 1,
#     "resolution": "512,768,1024",
#     "autocaption": True,
#     "input_images": "https://",
#     "trigger_word": "TOK",
#     "learning_rate": 0.0004,
#     "wandb_project": "flux_train_replicate",
#     "wandb_save_interval": 100,
#     "caption_dropout_rate": 0.05,
#     "cache_latents_to_disk": False,
#     "wandb_sample_interval": 100
# },

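training = client.trainings.create(
    version="stability-ai/sdxl:c221b2b8ef527988fb59bf24a8b97c4561f1c671f73bd389f866bfb27c061316",
    input={
        # Public URL of the zip file you uploaded earlier
        "input_images": "YOUR_ZIP_FILE_URL",
        # Unique token that refers to your subject or style in prompts
        "token_string": "TOK",
        # Optional caption prefix containing the token; you can remove it
        "caption_prefix": "a photo of TOK",
        "use_face_detection_instead": False,
        # Train a LoRA rather than fine-tuning all of the model's weights
        "is_lora": True,
    },
    # Replace with your own username and the model name you created above
    destination="snrism/YOUR_MODEL_NAME"
)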
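The FLUX note in the cell above says a higher `lora_rank` makes the model more adaptable at the cost of memory and compute. The standard LoRA formulation shows why: for a weight matrix of shape $d_{out} \times d_{in}$, LoRA trains two low-rank factors $B$ ($d_{out} \times r$) and $A$ ($r \times d_{in}$), so the trainable parameter count is $r(d_{out} + d_{in})$ and grows linearly with the rank $r$. The 1024 x 1024 layer below is a hypothetical size chosen for illustration, not a specific SDXL or FLUX layer.

```python
def lora_param_count(d_out: int, d_in: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one d_out x d_in weight matrix.

    LoRA learns W + B @ A with B: d_out x rank and A: rank x d_in,
    so it trains rank * (d_out + d_in) values instead of d_out * d_in.
    """
    if rank <= 0:
        raise ValueError("rank must be positive")
    return rank * (d_out + d_in)


# Hypothetical 1024 x 1024 projection layer, for illustration only
full = 1024 * 1024
for rank in (4, 16, 64):
    extra = lora_param_count(1024, 1024, rank)
    print(f"rank {rank:>2}: {extra:,} params ({extra / full:.1%} of the full matrix)")
```

Quadrupling the rank quadruples the trainable parameters per layer, which is the memory and compute trade-off the note describes.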

## Monitor training progress

Run the following cell to poll the training job and print its status and recent logs until it finishes.

In [None]:
import time

# Poll the training job until it reaches a terminal state
while True:
    # Refresh the training object with the latest status and logs from Replicate
    training.reload()
    logs = training.logs or ""  # logs can be empty before the job starts

    # Print the current status
    print(f"Status: {training.status}")

    if training.status == "processing":
        # Show the last 10 lines of logs
        print("\n".join(logs.splitlines()[-10:]))

    elif training.status == "succeeded":
        # Print the final logs and stop polling
        print("\nTraining succeeded! Here are the final logs:")
        print("\n".join(logs.splitlines()[-10:]))
        break

    elif training.status in ("failed", "canceled"):
        # Print the error and the full logs, then stop polling
        print(f"Training {training.status}. Error: {training.error}")
        print(logs)
        break

    # Wait before polling again (adjust the interval as needed)
    time.sleep(10)
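Printing the last 10 lines on every poll repeats the same lines while training is slow. An alternative is to print only what was appended since the previous poll. This is a sketch; `new_log_lines` is a hypothetical helper that assumes logs only grow between polls, and if the previous poll ended mid-line, the rest of that line shows up as a separate entry.

```python
from typing import List, Optional


def new_log_lines(previous_logs: Optional[str], current_logs: Optional[str]) -> List[str]:
    """Return non-empty log lines that appeared since the previous poll."""
    previous = previous_logs or ""
    current = current_logs or ""
    if current.startswith(previous):
        added = current[len(previous):]
    else:
        # Logs were truncated or replaced; fall back to showing everything
        added = current
    return [line for line in added.splitlines() if line.strip()]
```

Inside the polling loop, keep the previous `training.logs` value in a variable, print `new_log_lines(previous, training.logs)`, then update the variable.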

## Run the model

When training has finished, you can run the model from its page at replicate.com/my-name/my-model, or via the API. To invoke what the model learned, include your token string (for example, `TOK`) in the prompt.


In [None]:
# Store the result as `images` so it does not shadow google.colab's `output` module
images = client.run(
    training.output["version"],
    input={"prompt": "kids playing in the park with pinata"},
)
print(images)
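To keep the generated images in the Colab file system, you can write the run result to disk. This is a sketch under assumptions: depending on the `replicate` client version, outputs are either file-like objects with a `read()` method or plain URL strings, and the `.png` extension is assumed. `save_outputs` is a hypothetical helper.

```python
import os
import urllib.request
from typing import List


def save_outputs(outputs, out_dir: str = "outputs", prefix: str = "image", ext: str = ".png") -> List[str]:
    """Write each model output to `out_dir` and return the saved file paths."""
    # Some models return a single item instead of a list
    if isinstance(outputs, str) or hasattr(outputs, "read"):
        outputs = [outputs]
    os.makedirs(out_dir, exist_ok=True)
    paths = []
    for i, item in enumerate(outputs):
        if hasattr(item, "read"):
            data = item.read()  # file-like output (assumed newer client behavior)
        else:
            with urllib.request.urlopen(str(item)) as response:  # plain URL string
                data = response.read()
        path = os.path.join(out_dir, f"{prefix}_{i}{ext}")
        with open(path, "wb") as f:
            f.write(data)
        paths.append(path)
    return paths
```

Pass it the value returned by `client.run` in the cell above; it returns the list of saved file paths.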