In [None]:
%%writefile code/inference.py
import os
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from compel import Compel, ReturnedEmbeddingsType
import json
import base64
from io import BytesIO
import asyncio
import subprocess
import boto3

def model_fn(model_dir):
    """Load the base and refiner DiffusionPipelines plus their Compel prompt encoders.

    ``model_dir`` must contain three sub-directories:

    * ``base/`` - the base pipeline,
    * ``refiner/`` - the refiner pipeline,
    * ``Trained_lora/`` - holding ``pytorch_lora_weights.safetensors``.

    Both pipelines are loaded in float16 from the ``fp16`` safetensors variant
    and moved to CUDA.

    Returns:
        Tuple ``(base, refiner, compel, compel_refiner)``, in the order that
        ``predict_fn`` unpacks it.
    """
    # Shared loading options for both pipelines.
    half_precision = {
        "torch_dtype": torch.float16,
        "variant": "fp16",
        "use_safetensors": True,
    }

    base_pipe = DiffusionPipeline.from_pretrained(
        os.path.join(model_dir, "base"), **half_precision
    ).to("cuda")

    # Apply the fine-tuned LoRA adapter to the base pipeline.
    # NOTE(review): if the LoRA file also carries text-encoder weights, the
    # refiner sees them too via the shared text_encoder_2 below - confirm intended.
    base_pipe.load_lora_weights(
        os.path.join(model_dir, "Trained_lora"),
        weight_name="pytorch_lora_weights.safetensors",
    )

    # Reuse the base pipeline's second text encoder and VAE rather than
    # loading separate copies for the refiner.
    refiner_pipe = DiffusionPipeline.from_pretrained(
        os.path.join(model_dir, "refiner"),
        text_encoder_2=base_pipe.text_encoder_2,
        vae=base_pipe.vae,
        **half_precision,
    ).to("cuda")

    embeddings_kind = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED

    # Base encoder: both tokenizer/encoder pairs; only the second one
    # produces the pooled embedding.
    base_encoder = Compel(
        tokenizer=[base_pipe.tokenizer, base_pipe.tokenizer_2],
        text_encoder=[base_pipe.text_encoder, base_pipe.text_encoder_2],
        returned_embeddings_type=embeddings_kind,
        requires_pooled=[False, True],
    )

    # Refiner encoder: conditions through tokenizer_2/text_encoder_2 only.
    refiner_encoder = Compel(
        tokenizer=[refiner_pipe.tokenizer_2],
        text_encoder=[refiner_pipe.text_encoder_2],
        returned_embeddings_type=embeddings_kind,
        requires_pooled=[True],
    )

    return base_pipe, refiner_pipe, base_encoder, refiner_encoder

def predict_fn(data, models):
    """Handle one request: run a training job, or generate an image.

    Args:
        data: Decoded request payload (a dict); it is not modified.

            * If ``data["action"] == "train"``, the values of
              ``collection_s3_path``, ``prompt`` and ``output_dir_name`` are
              forwarded to ``train_model``.
            * Otherwise the request is an inference request with these
              optional keys:

              - ``prompt`` and ``negative_prompt`` (missing or null -> ``""``),
              - ``num_inference_steps`` (default 40),
              - ``high_noise_frac`` (default 0.8): the fraction of the schedule
                run by the base model before the refiner takes over.
        models: The ``(base, refiner, compel, compel_refiner)`` tuple
            returned by ``model_fn``.

    Returns:
        For training, whatever ``train_model`` returns. For inference,
        ``{"image": <base64-encoded PNG string>}``.

    Raises:
        KeyError: A training request is missing one of its required keys.
        ValueError: ``num_inference_steps`` < 1, or ``high_noise_frac`` is
            not strictly between 0 and 1.
    """
    if data.get("action") == "train":
        # NOTE(review): `train_model` is not defined in the visible code -
        # confirm it exists in the deployed module, otherwise this raises NameError.
        # asyncio.run creates and closes its own event loop, so it works from any
        # thread. asyncio.get_event_loop() raises RuntimeError in threads without
        # a current loop and is deprecated when no loop is running.
        return asyncio.run(train_model(
            data["collection_s3_path"],
            data["prompt"],
            data["output_dir_name"],
        ))

    base, refiner, compel, compel_refiner = models

    # `.get(...) or ""` leaves the caller's dict untouched and also maps an
    # explicit null to the empty prompt.
    prompt = data.get("prompt") or ""
    negative_prompt = data.get("negative_prompt") or ""
    num_inference_steps = int(data.get("num_inference_steps", 40))
    high_noise_frac = float(data.get("high_noise_frac", 0.8))
    if num_inference_steps < 1:
        raise ValueError(f"num_inference_steps must be >= 1, got {num_inference_steps}")
    if not 0.0 < high_noise_frac < 1.0:
        raise ValueError(f"high_noise_frac must be in (0, 1), got {high_noise_frac}")

    # Pure inference: make sure no autograd state is built for the prompt
    # embeddings, which stay alive for the whole base + refiner run.
    with torch.no_grad():
        conditioning, pooled = compel(prompt)
        negative_conditioning, negative_pooled = compel(negative_prompt)
        conditioning_refiner, pooled_refiner = compel_refiner(prompt)
        negative_conditioning_refiner, negative_pooled_refiner = compel_refiner(negative_prompt)

        # The base model denoises the first `high_noise_frac` of the schedule
        # and returns latents, not decoded images.
        latents = base(
            prompt_embeds=conditioning,
            pooled_prompt_embeds=pooled,
            negative_prompt_embeds=negative_conditioning,
            negative_pooled_prompt_embeds=negative_pooled,
            num_inference_steps=num_inference_steps,
            denoising_end=high_noise_frac,
            output_type="latent",
        ).images

        # The refiner continues from the same point in the schedule, taking the
        # whole latent batch, and decodes the first image.
        refined = refiner(
            prompt_embeds=conditioning_refiner,
            pooled_prompt_embeds=pooled_refiner,
            negative_prompt_embeds=negative_conditioning_refiner,
            negative_pooled_prompt_embeds=negative_pooled_refiner,
            num_inference_steps=num_inference_steps,
            denoising_start=high_noise_frac,
            image=latents,
        ).images[0]

    buffered = BytesIO()
    refined.save(buffered, format="PNG")
    return {"image": base64.b64encode(buffered.getvalue()).decode()}
