In [None]:
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.utils import make_image_grid
from image_noiser import ImageNoiser
import torch
import open_clip
from artifact_estimator.model import load_model, preprocess


In [None]:
DEVICE = "cuda"

# Stable Diffusion 1.5 in fp16. enable_sequential_cpu_offload() handles device placement
# itself (submodules are moved to the GPU only while they execute), so the pipeline is NOT
# moved to DEVICE beforehand: an explicit .to("cuda") first is redundant and temporarily
# puts the whole model in GPU memory.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()  # memory saving: VAE decodes in slices
pipe.vae.enable_tiling()   # memory saving: VAE decodes in tiles
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

image_noiser = ImageNoiser(sd_pipe=pipe)  # reuses the pipeline's components

# CLIP ViT-B/16 fine-tuned on noisy images. open_clip returns models in train mode,
# so switch to eval mode explicitly for deterministic inference.
clip, _, preprocess_clip = open_clip.create_model_and_transforms(
    "hf-hub:Outrun32/CLIP-ViT-B-16-noise-tuned"
)
clip = clip.to(DEVICE).eval()

# Artifact-estimator MLP that scores CLIP image embeddings.
# NOTE(review): confirm load_model() already returns the model in eval mode.
mlp = load_model("artifact_estimator/models/artifact_estimator_openclip_vit_b_16.pth").to(DEVICE)


In [None]:
artifacted = []  # batch indices flagged as artifacted during the current pipe() run
PREDICTION_THRESHOLD = 3.0  # estimator score above which an image is treated as artifacted
CHECK_EVERY_N_STEPS = 10  # run the CLIP + MLP check on every N-th denoising step (never step 0)
CLIP_DEVICE = "cuda"  # device the CLIP model was moved to; preprocessed images are sent here


def latents_callback_prediction_with_removal(pipe, i, t, kwargs):
    """diffusers ``callback_on_step_end`` hook that flags artifacted images mid-generation.

    On every ``CHECK_EVERY_N_STEPS``-th step (excluding step 0) the current latents are
    approximated as images with ``image_noiser``, encoded with CLIP in a single batch,
    and scored by the artifact-estimator MLP. The batch index of every image whose score
    exceeds ``PREDICTION_THRESHOLD`` is appended (once) to the global ``artifacted`` list.

    Images are only flagged, not removed: ``kwargs`` is returned unchanged, so the
    pipeline still denoises the full batch and filtering must happen afterwards.

    Args:
        pipe: the running pipeline (unused; required by the callback signature).
        i: zero-based index of the denoising step that just finished.
        t: the scheduler timestep (unused).
        kwargs: the callback tensor dict; must contain ``"latents"``.

    Returns:
        ``kwargs``, unmodified.
    """
    # The step index is zero-based, so with num_inference_steps=20 the check runs
    # only at i == 10 (i never reaches 20).
    if i == 0 or i % CHECK_EVERY_N_STEPS != 0:
        return kwargs

    latents = kwargs["latents"]
    # The pipeline already calls this under no_grad; this keeps the hook safe to call directly.
    with torch.no_grad():
        approximated_images = image_noiser.approx_latents_batch(latents)
        clip_inputs = torch.stack(
            [preprocess_clip(image) for image in approximated_images]
        ).to(CLIP_DEVICE)
        clip_embeddings = clip.encode_image(clip_inputs)
        # split(1) yields one single-image slice per embedding, so preprocess/mlp
        # each receive a batch of one.
        scores = [mlp(preprocess(embedding)) for embedding in clip_embeddings.split(1)]

    for batch_idx, score in enumerate(scores):
        score_value = score.item()
        if score_value > PREDICTION_THRESHOLD and batch_idx not in artifacted:
            print(f"Image {batch_idx} looks artifacted (score {score_value:.2f} at step {i}); it will be excluded from the results.")
            artifacted.append(batch_idx)

    # TODO: flagged samples are still denoised to the end. Truly dropping them would mean
    # slicing kwargs["latents"] and the matching prompt embeddings here, which likely
    # requires a custom pipeline.
    return kwargs

In [None]:
num_images = 4
NUM_INFERENCE_STEPS = 20
SEED = 42  # fixed seed so the same images (and artifact flags) are reproduced on re-run
prompt = "A photo of deformed doggo, extremely distorted, bad, artifacts, horrible"

# The callback appends flagged batch indices to the global `artifacted`; reset it per run.
artifacted.clear()
images = pipe(
    prompt,
    num_inference_steps=NUM_INFERENCE_STEPS,
    num_images_per_prompt=num_images,
    generator=torch.Generator(device="cpu").manual_seed(SEED),
    callback_on_step_end=latents_callback_prediction_with_removal,
    callback_on_step_end_tensor_inputs=["latents", "prompt_embeds", "negative_prompt_embeds"],
).images

# Keep only the images the estimator did not flag during generation.
final_images = [image for idx, image in enumerate(images) if idx not in artifacted]
if final_images:
    grid = make_image_grid(final_images, rows=1, cols=len(final_images))
else:
    grid = None  # avoid NameError / showing a stale grid from a previous run
    print("All images were artifacted")
grid