In [None]:
async def train_model(collection_s3_path, prompt, output_dir_name):
    """Run DreamBooth LoRA training on an S3 image collection.

    Downloads every object under ``collection_s3_path`` to
    ``/tmp/training_data``, runs ``train_dreambooth_lora_sdxl.py`` through
    ``accelerate launch``, and, if the process exits with code 0, uploads
    ``/tmp/trained_model`` to ``model_outputs/<output_dir_name>`` in the
    same bucket the collection came from.

    Args:
        collection_s3_path: ``s3://bucket/prefix`` URI of the training data.
        prompt: Instance prompt forwarded to the training script.
        output_dir_name: Name of the folder created under ``model_outputs/``.

    Returns:
        ``{"status": "success", "output_s3_path": ...}`` when training
        succeeds, otherwise ``{"status": "failed", "error": <stderr text>}``.
    """
    s3 = boto3.resource('s3')
    collection_bucket, collection_key = parse_s3_uri(collection_s3_path)

    local_collection_path = '/tmp/training_data'
    local_model_path = '/tmp/trained_model'
    download_from_s3(s3, collection_bucket, collection_key, local_collection_path)

    output_s3_path = f"s3://{collection_bucket}/model_outputs/{output_dir_name}"

    # Pass the arguments as a list instead of a shell string: the prompt is
    # caller-supplied, and quotes, `$` or backticks inside it would otherwise
    # be interpreted by the shell (broken command or command injection).
    training_args = [
        "accelerate", "launch", "train_dreambooth_lora_sdxl.py",
        "--pretrained_model_name_or_path=/opt/ml/model/base",
        f"--instance_data_dir={local_collection_path}",
        "--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
        f"--output_dir={local_model_path}",
        "--mixed_precision=fp16",
        f"--instance_prompt={prompt}",
        "--resolution=1024",
        "--train_batch_size=2",
        "--gradient_accumulation_steps=2",
        "--gradient_checkpointing",
        "--learning_rate=1e-4",
        "--lr_scheduler=constant",
        "--lr_warmup_steps=0",
        "--max_train_steps=500",
        "--seed=0",
    ]

    process = await asyncio.create_subprocess_exec(
        *training_args,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE
    )
    # communicate() drains both pipes so the child cannot block on a full
    # pipe buffer; only stderr is reported back to the caller.
    _stdout, stderr = await process.communicate()

    if process.returncode != 0:
        # errors="replace" keeps the failure report usable even if the
        # training process wrote bytes that are not valid UTF-8.
        return {"status": "failed", "error": stderr.decode(errors="replace")}

    upload_to_s3(s3, local_model_path, collection_bucket, f"model_outputs/{output_dir_name}")
    return {"status": "success", "output_s3_path": output_s3_path}

def parse_s3_uri(uri):
    """Split an S3 URI into its bucket name and object key.

    Only a leading ``s3://`` scheme is removed; any later occurrence of that
    text is part of the key and is preserved.

    Args:
        uri: A string such as ``s3://bucket/path/to/key``. A value without
            the scheme is treated as ``bucket/key``.

    Returns:
        A ``(bucket, key)`` tuple. ``key`` is ``""`` when the URI names only
        a bucket.
    """
    scheme = "s3://"
    if uri.startswith(scheme):
        uri = uri[len(scheme):]
    # Everything up to the first "/" is the bucket; the rest is the key.
    bucket, _, key = uri.partition("/")
    return bucket, key

def download_from_s3(s3, bucket, key, local_path):
    """Download every object under an S3 prefix into a local directory.

    Objects are written beneath ``local_path`` at their path relative to
    ``key``. Keys ending in ``/`` (folder placeholders) are skipped.

    Args:
        s3: Object exposing ``Bucket(name)`` whose result provides
            ``objects.filter(Prefix=...)`` and ``download_file(key, path)``
            (called with ``boto3.resource('s3')`` elsewhere in this file).
        bucket: Name of the bucket to read from.
        key: Key prefix to download; may also be the key of a single object.
        local_path: Destination directory, created if missing.
    """
    os.makedirs(local_path, exist_ok=True)
    source_bucket = s3.Bucket(bucket)
    for obj in source_bucket.objects.filter(Prefix=key):
        if obj.key.endswith('/'):
            continue
        relative = os.path.relpath(obj.key, key)
        if relative == os.curdir:
            # `key` names this object exactly; store it under its file name
            # rather than trying to write to the directory itself.
            relative = os.path.basename(obj.key)
        elif relative == os.pardir or relative.startswith(os.pardir + os.sep):
            # Prefix matching is textual, so "data" also matches
            # "data2/x"; such keys would land outside local_path.
            continue
        target = os.path.join(local_path, relative)
        os.makedirs(os.path.dirname(target), exist_ok=True)
        source_bucket.download_file(obj.key, target)

def upload_to_s3(s3, local_dir, bucket, s3_path):
    """Upload a local directory tree to S3, preserving its layout.

    Every file found under ``local_dir``, including files in nested
    subdirectories, is uploaded to ``<s3_path>/<path relative to local_dir>``.

    Args:
        s3: Object exposing ``Bucket(name)`` whose result provides
            ``upload_file(local_file, key)`` (called with
            ``boto3.resource('s3')`` elsewhere in this file).
        local_dir: Root of the directory tree to upload.
        bucket: Name of the destination bucket.
        s3_path: Key prefix under which the files are stored.

    Returns:
        ``True`` once the whole tree has been walked.
    """
    import posixpath

    destination = s3.Bucket(bucket)
    for root, _, files in os.walk(local_dir):
        for file in files:
            local_file = os.path.join(root, file)
            relative_path = os.path.relpath(local_file, local_dir)
            # S3 keys always use "/" regardless of the local path separator.
            s3_file = posixpath.join(s3_path, relative_path.replace(os.sep, "/"))
            destination.upload_file(local_file, s3_file)
    # Return only after the full walk so files in subdirectories are
    # uploaded too, not just those in the top-level directory.
    return True

## 3. Download and Prepare the Model

In [None]:
from distutils.dir_util import copy_tree
from pathlib import Path
from huggingface_hub import snapshot_download
import random
import os

# Set up model IDs and tokens
BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
REFINER_MODEL_ID = "stabilityai/stable-diffusion-xl-refiner-1.0"
HF_TOKEN = os.environ.get('HF_TOKEN')
# os.environ.get returns None when the variable is unset, so test truthiness
# rather than len(): a missing token then produces this message instead of a
# TypeError from len(None).
if not HF_TOKEN:
    raise RuntimeError("Please set HF_TOKEN to your huggingface token.")

# Create a directory for the model; the random 16-bit suffix keeps separate
# runs from sharing a folder in most cases (collisions are possible).
model_tar = Path(f"model-{random.getrandbits(16)}")
model_tar.mkdir(exist_ok=True)

# Download the base model snapshot and copy it to <model_tar>/base
print("Downloading base model...")
base_snapshot_dir = snapshot_download(repo_id=BASE_MODEL_ID, revision="main", use_auth_token=HF_TOKEN)
base_model_dir = model_tar / "base"
base_model_dir.mkdir(exist_ok=True)
copy_tree(base_snapshot_dir, str(base_model_dir))

# Download the refiner model snapshot and copy it to <model_tar>/refiner
print("Downloading refiner model...")
refiner_snapshot_dir = snapshot_download(repo_id=REFINER_MODEL_ID, revision="main", use_auth_token=HF_TOKEN)
refiner_model_dir = model_tar / "refiner"
refiner_model_dir.mkdir(exist_ok=True)
copy_tree(refiner_snapshot_dir, str(refiner_model_dir))

# Create a directory for LoRA weights (left empty unless you copy weights in)
lora_dir = model_tar / "Trained_lora"
lora_dir.mkdir(exist_ok=True)
# If you have LoRA weights, uncomment and modify the following line:
# copy_tree("path/to/your/lora/weights", str(lora_dir))

# Copy the local code/ directory to <model_tar>/code
code_dir = model_tar / "code"
code_dir.mkdir(exist_ok=True)
copy_tree("code/", str(code_dir))

print(f"Model files prepared in directory: {model_tar}")