In [None]:
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, EulerAncestralDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, set_seed
from PIL import Image
from typing import Any, Dict, Optional
import torch
from torchvision import transforms
import numpy as np

class Zero123PlusPipeline(StableDiffusionPipeline):
    """Stable Diffusion pipeline conditioned on a single reference image.

    In addition to the components handed to ``StableDiffusionPipeline`` the
    pipeline keeps a vision encoder, two image feature extractors (one whose
    output is fed to the VAE, one whose output is fed to the vision encoder)
    and a list of ramping coefficients used to blend the global image
    embedding into the prompt embeddings.
    """

    def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, vision_encoder, feature_extractor_clip, feature_extractor_vae, ramping_coefficients, *args, **kwargs):
        """Store all components and build the depth-image transform.

        Args:
            vae, text_encoder, tokenizer, unet, scheduler: forwarded to
                ``StableDiffusionPipeline.__init__``.
            vision_encoder: model called on the CLIP-preprocessed image; its
                ``image_embeds`` output is used.
            feature_extractor_clip: preprocessor for the vision encoder input.
            feature_extractor_vae: preprocessor for the VAE input; also passed
                to the parent as ``feature_extractor``.
            ramping_coefficients: sequence of floats scaling the global image
                embedding before it is added to the prompt embeddings.
            *args, **kwargs: forwarded to the parent constructor.
        """
        super().__init__(*args, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=None, feature_extractor=feature_extractor_vae, **kwargs)

        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.unet = unet
        self.scheduler = scheduler
        self.vision_encoder = vision_encoder
        self.feature_extractor_clip = feature_extractor_clip
        self.feature_extractor_vae = feature_extractor_vae
        self.ramping_coefficients = ramping_coefficients
        # Depth maps are converted to tensors and normalized with mean 0.5 / std 0.5.
        self.depth_transforms_multi = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

    def prepare_unet_call(self, image: Image.Image, prompt: str, num_images_per_prompt: int, depth_image: Image.Image = None):
        """Build every conditioning input required for denoising.

        Args:
            image: reference PIL image (RGB or RGBA); tensors are rejected.
            prompt: text prompt passed to the prompt encoder.
            num_images_per_prompt: forwarded to the prompt encoder.
            depth_image: optional PIL depth image; only used when the UNet
                exposes a ``controlnet`` attribute.

        Returns:
            Tuple ``(image_1, cond_lat, encoder_hidden_states,
            cross_attention_kwargs)`` where ``image_1`` is the VAE-preprocessed
            image tensor, ``cond_lat`` its sampled VAE latent,
            ``encoder_hidden_states`` the prompt embeddings with the ramped
            global image embedding added, and ``cross_attention_kwargs`` a dict
            holding ``cond_lat`` (plus ``control_depth`` when a controlnet is
            present).
        """
        assert not isinstance(image, torch.Tensor)

        image = self.to_rgb_image(image)
        image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values
        image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values

        if depth_image is not None and hasattr(self.unet, "controlnet"):
            depth_image = self.to_rgb_image(depth_image)
            depth_image = self.depth_transforms_multi(depth_image).to(
                device=self.unet.controlnet.device, dtype=self.unet.controlnet.dtype
            )

        image_1 = image_1.to(device=self.vae.device, dtype=self.vae.dtype)
        image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype)

        cond_lat = self.encode_condition_image(image_1)

        # Newer diffusers expose `encode_prompt` (returns a tuple); older
        # releases only have `_encode_prompt` (returns the tensor directly).
        if hasattr(self, "encode_prompt"):
            encoder_hidden_states = self.encode_prompt(prompt, self.device, num_images_per_prompt, False)[0]
        else:
            encoder_hidden_states = self._encode_prompt(prompt, self.device, num_images_per_prompt, False)

        encoded = self.vision_encoder(image_2, output_hidden_states=False)
        global_embeds = encoded.image_embeds.unsqueeze(-2)

        # One coefficient per position along the second-to-last axis of the
        # prompt embeddings; the product broadcasts against them.
        ramp = global_embeds.new_tensor(self.ramping_coefficients).unsqueeze(-1)
        encoder_hidden_states = encoder_hidden_states + global_embeds * ramp
        cross_attention_kwargs = dict(cond_lat=cond_lat)

        if hasattr(self.unet, "controlnet"):
            cross_attention_kwargs['control_depth'] = depth_image

        return image_1, cond_lat, encoder_hidden_states, cross_attention_kwargs

    def encode_condition_image(self, image: torch.Tensor):
        """Encode ``image`` with the VAE and return a sample of the latent distribution."""
        return self.vae.encode(image).latent_dist.sample()

    def run_pipeline(self, image_1: torch.Tensor, cond_lat: torch.Tensor, encoder_hidden_states: torch.Tensor, cross_attention_kwargs: Dict[str, Any], guidance_scale: float, num_images_per_prompt: int, **kwargs):
        """Run the parent denoising loop and return the resulting latents.

        Args:
            image_1: VAE-preprocessed reference image tensor.
            cond_lat: VAE latent of ``image_1``.
            encoder_hidden_states: prompt embeddings passed as ``prompt_embeds``.
            cross_attention_kwargs: extra kwargs forwarded to the UNet.
            guidance_scale: guidance weight; values above 1 add a negative
                condition latent encoded from an all-zero image.
            num_images_per_prompt: forwarded to the parent ``__call__``.
            **kwargs: forwarded to the parent ``__call__``.

        Returns:
            The ``images`` field of the parent output, requested with
            ``output_type='latent'``.
        """
        if guidance_scale > 1:
            negative_lat = self.encode_condition_image(torch.zeros_like(image_1))
            cond_lat = torch.cat([negative_lat, cond_lat])
            # Previously the concatenated latent was computed but never used:
            # the dict still carried the unconditional-free latent. Build a new
            # dict so the caller's mapping is not mutated.
            cross_attention_kwargs = {**cross_attention_kwargs, "cond_lat": cond_lat}

        latents = super().__call__(
            None,
            cross_attention_kwargs=cross_attention_kwargs,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            prompt_embeds=encoder_hidden_states,
            output_type='latent',
            **kwargs
        ).images

        return latents

    def process_output(self, latents: torch.Tensor, output_type: Optional[str] = "pil", return_dict: bool = True):
        """Turn denoised latents into the requested output format.

        Args:
            latents: latents produced by the denoising loop.
            output_type: format handed to ``image_processor.postprocess``;
                ``"latent"`` skips VAE decoding.
            return_dict: when false the post-processed images are returned
                directly instead of an ``ImagePipelineOutput``.
        """
        # Imported here because the file-level imports do not provide it; the
        # original code raised NameError on the `return_dict=True` path.
        from diffusers import ImagePipelineOutput

        latents = self.unscale_latents(latents)
        if output_type != "latent":
            decoded = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image = self.unscale_image(decoded)
        else:
            image = latents

        image = self.image_processor.postprocess(image, output_type=output_type)

        if not return_dict:
            return image

        return ImagePipelineOutput(images=image)

    @torch.no_grad()
    def generate(self, image: Image.Image = None, prompt: str = "", *args, num_images_per_prompt: Optional[int] = 1, guidance_scale: float = 4.0, depth_image: Image.Image = None, output_type: Optional[str] = "pil", width: int = 640, height: int = 960, num_inference_steps: int = 28, return_dict: bool = True, **kwargs):
        """Generate images from a reference image.

        Args:
            image: PIL image or a path string (opened and converted to RGB).
            prompt: optional text prompt.
            *args: accepted for signature compatibility; not used.
            num_images_per_prompt: number of samples per prompt.
            guidance_scale: guidance weight.
            depth_image: optional depth image for a controlnet-equipped UNet.
            output_type: output format, see ``process_output``.
            width, height, num_inference_steps: forwarded to the parent
                denoising call.
            return_dict: see ``process_output``.
            **kwargs: forwarded to the parent denoising call.

        Raises:
            ValueError: if ``image`` is ``None``.
        """
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")

        if image is None:
            raise ValueError("Inputting embeddings not supported for this pipeline. Please pass an image.")

        image_1, cond_lat, encoder_hidden_states, cross_attention_kwargs = self.prepare_unet_call(image, prompt, num_images_per_prompt, depth_image)
        # width/height/num_inference_steps used to be accepted and then
        # dropped; forward them so the arguments actually take effect.
        latents = self.run_pipeline(
            image_1,
            cond_lat,
            encoder_hidden_states,
            cross_attention_kwargs,
            guidance_scale,
            num_images_per_prompt,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            **kwargs
        )
        return self.process_output(latents, output_type, return_dict)

    @staticmethod
    def to_rgb_image(maybe_rgba: Image.Image):
        """Return an RGB version of ``maybe_rgba``.

        RGB images are returned unchanged. RGBA images are composited onto a
        uniform gray (127, 127, 127) canvas using their alpha channel as mask.

        Raises:
            ValueError: for any other image mode.
        """
        if maybe_rgba.mode == 'RGB':
            return maybe_rgba
        if maybe_rgba.mode == 'RGBA':
            # Same constant canvas the former randint(127, 128) call produced.
            canvas = Image.new('RGB', maybe_rgba.size, (127, 127, 127))
            canvas.paste(maybe_rgba, mask=maybe_rgba.getchannel('A'))
            return canvas
        raise ValueError("Unsupported image type.", maybe_rgba.mode)

    @staticmethod
    def unscale_latents(latents):
        """Apply the affine map ``latents / 0.75 + 0.22``."""
        return (latents / 0.75) + 0.22

    @staticmethod
    def unscale_image(image):
        """Apply the scaling ``image / 0.5 * 0.8``."""
        return image / 0.5 * 0.8


In [None]:
import os
import gc
import torch
from PIL import Image
from tqdm.auto import tqdm
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, set_seed, CLIPImageProcessor
from diffusers import AutoencoderKL, UNet2DConditionModel, EulerAncestralDiscreteScheduler

class Zero123PlusSimSim314Generator(Zero123PlusPipeline):
    """Config-driven wrapper that loads all pipeline components from a checkpoint.

    The ``config`` mapping must provide ``pretrained_model_name_or_path``,
    ``ramping_coefficients``, ``seed``, ``num_inference_steps`` and
    ``output_dir``. An optional ``device`` entry selects where the models are
    placed; without it CUDA is used when available, otherwise the CPU.
    """

    def __init__(self, config):
        """Load every sub-model from ``config`` and seed the RNGs."""
        self.pipeline_config = config

        model_id = config['pretrained_model_name_or_path']
        # `self.device` cannot be used here: the parent constructor has not
        # run yet, so no module is registered to derive a device from.
        device = torch.device(
            config.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        )

        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
        # NOTE(review): the upstream Zero123++ pipeline wraps the UNet so that
        # `cond_lat` in cross_attention_kwargs is consumed; a plain
        # UNet2DConditionModel may ignore it - confirm this is intended.
        unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(device)
        scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
        vision_encoder = CLIPVisionModelWithProjection.from_pretrained(model_id, subfolder="vision_encoder").to(device)

        # Use the preprocessors shipped with the checkpoint rather than the
        # generic CLIP ViT-B/32 one, so each model sees the preprocessing its
        # checkpoint declares.
        feature_extractor_clip = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor_clip")
        feature_extractor_vae = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor_vae")

        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            vision_encoder=vision_encoder,
            feature_extractor_clip=feature_extractor_clip,
            feature_extractor_vae=feature_extractor_vae,
            ramping_coefficients=config["ramping_coefficients"]
        )
        self.set_seed()

    def set_seed(self):
        """Seed the random number generators with ``config['seed']``."""
        # Resolves to the module-level `transformers.set_seed`, not this method.
        set_seed(self.pipeline_config['seed'])

    @torch.no_grad()
    def __call__(self, image: Image.Image = None, prompt: str = "", *args, num_images_per_prompt: Optional[int] = 1, guidance_scale: float = 4.0, depth_image: Image.Image = None, output_type: Optional[str] = "pil", width: int = 640, height: int = 960, num_inference_steps: int = 28, return_dict: bool = True, **kwargs):
        """Run a manual denoising loop, saving a decoded preview after every step.

        The number of steps is read from ``config['num_inference_steps']`` and
        previews are written to ``config['output_dir']`` as
        ``generated_image_<step>.png``. The ``guidance_scale``, ``width``,
        ``height`` and ``num_inference_steps`` arguments are accepted for
        signature compatibility but are not used by this loop.

        Args:
            image: PIL image or a path string (opened and converted to RGB).
            prompt: optional text prompt.
            num_images_per_prompt: forwarded to ``prepare_unet_call``.
            depth_image: forwarded to ``prepare_unet_call``.
            output_type: forwarded to ``process_output``.
            return_dict: forwarded to ``process_output``.

        Raises:
            ValueError: if ``image`` is ``None``.
        """
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")

        if image is None:
            raise ValueError("Inputting embeddings not supported for this pipeline. Please pass an image.")

        image_1, cond_lat, encoder_hidden_states, cross_attention_kwargs = self.prepare_unet_call(image, prompt, num_images_per_prompt, depth_image)

        # Make sure the preview directory exists before the first save.
        output_dir = self.pipeline_config['output_dir']
        os.makedirs(output_dir, exist_ok=True)

        self.scheduler.set_timesteps(self.pipeline_config['num_inference_steps'])
        # NOTE(review): the loop starts from the encoded reference image scaled
        # by init_noise_sigma, not from random noise - confirm this is intended.
        latents = self.vae.encode(image_1).latent_dist.sample() * self.scheduler.init_noise_sigma

        # The method already runs under torch.no_grad(), so no nested
        # no_grad blocks are needed inside the loop.
        for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="Generating images")):
            latent_model_input = self.scheduler.scale_model_input(latents, t)

            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs).sample

            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

            # Decode the current latents into a preview image in [0, 1].
            preview = self.vae.decode(1 / 0.18215 * latents).sample
            preview = (preview / 2 + 0.5).clamp(0, 1)
            preview = preview.detach().cpu().permute(0, 2, 3, 1).numpy()
            preview_uint8 = (preview * 255).round().astype("uint8")
            pil_image = Image.fromarray(preview_uint8[0])

            output_path = os.path.join(output_dir, f"generated_image_{i}.png")
            pil_image.save(output_path)

        self.cleanup()
        return self.process_output(latents, output_type, return_dict)

    def cleanup(self):
        """Release unreferenced Python objects and cached CUDA memory."""
        # Collect first so tensors freed by the GC can be returned to the
        # driver by empty_cache().
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Run configuration consumed by the generator class and the script below.
config = {
    # Reference image to condition on, and where results are written.
    "input_image_path": "000008_021b957fdc234cf09da4a069cdf57a9b_0.png",
    "output_dir": "path_to_output_dir",
    # Checkpoint from which all sub-models are loaded.
    "pretrained_model_name_or_path": "sudo-ai/zero123plus-v1.1",
    "resolution": 512,
    "seed": 42,
    "guidance_scale": 1.0,
    "num_inference_steps": 150,
    # Weights applied to the global image embedding before it is added to
    # the prompt embeddings (77 entries).
    "ramping_coefficients": [
        0.0,
        0.2060057818889618,
        0.18684479594230652,
        0.24342191219329834,
        0.18507817387580872,
        0.1703828126192093,
        0.15628913044929504,
        0.14174538850784302,
        0.13617539405822754,
        0.13569170236587524,
        0.1269884556531906,
        0.1200924888253212,
        0.12816639244556427,
        0.13058121502399445,
        0.14201879501342773,
        0.15004529058933258,
        0.1620427817106247,
        0.17207716405391693,
        0.18534132838249207,
        0.20002241432666779,
        0.21657466888427734,
        0.22996725142002106,
        0.24613411724567413,
        0.25141021609306335,
        0.26613450050354004,
        0.271847128868103,
        0.2850190997123718,
        0.285749226808548,
        0.2813953757286072,
        0.29509517550468445,
        0.30109965801239014,
        0.31370124220848083,
        0.3134534955024719,
        0.3108579218387604,
        0.32147032022476196,
        0.33548328280448914,
        0.3301997184753418,
        0.3254660964012146,
        0.3514464199542999,
        0.35993096232414246,
        0.3510829508304596,
        0.37661612033843994,
        0.3913513123989105,
        0.42122599482536316,
        0.3954688012599945,
        0.4260983467102051,
        0.479139506816864,
        0.4588979482650757,
        0.4873477816581726,
        0.5095643401145935,
        0.5133851170539856,
        0.520708441734314,
        0.5363377928733826,
        0.5661528706550598,
        0.5859065651893616,
        0.6207258701324463,
        0.6560986638069153,
        0.6379964351654053,
        0.6777164340019226,
        0.6589891910552979,
        0.7574057579040527,
        0.7446827292442322,
        0.7695522308349609,
        0.8163619041442871,
        0.9502472281455994,
        0.9918442368507385,
        0.9398387670516968,
        1.005432367324829,
        0.9295969605445862,
        0.9899859428405762,
        1.044832706451416,
        1.0427014827728271,
        1.0829696655273438,
        1.0062562227249146,
        1.0966323614120483,
        1.0550328493118286,
        1.2108079195022583,
    ],
}

os.makedirs(config['output_dir'], exist_ok=True)

generator = Zero123PlusSimSim314Generator(config)

# Pass the configured sampling settings explicitly (they were previously
# defined but never used) and keep the result instead of discarding it.
# `return_dict=False` makes `generate` return the post-processed images
# directly.
generated_views = generator.generate(
    config['input_image_path'],
    guidance_scale=config['guidance_scale'],
    num_inference_steps=config['num_inference_steps'],
    return_dict=False,
)

# Persist every generated image into the output directory.
for view_index, view in enumerate(generated_views):
    view.save(os.path.join(config['output_dir'], f"generated_view_{view_index}.png"))
