In [1]:
import torch
from PIL import Image
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer
from tqdm import trange

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler 

In [3]:
import clip
import numpy as np

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
style_prompt = "in the style of a renaissance painting"
num_steps = 50
guidance_scale = 7.5
lr = 0.05

In [5]:
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(device)
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet").to(device)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

In [6]:
clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)

In [1]:
# def div_by_8(x):
#     return x - (x % 8)

In [8]:
clip_tokens = clip.tokenize([style_prompt]).to(device)
with torch.no_grad():
    style_embedding = clip_model.encode_text(clip_tokens).float()

In [9]:
input_image = Image.open("objects.jpg").convert("RGB").resize((512, 512))
# o_width, o_height = input_image.size
# a_width, a_height = div_by_8(o_width), div_by_8(o_height)
# resized_input = input_image.resize((a_width, a_height), resample=Image.LANCZOS)
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

In [10]:
text_token = clip.tokenize([style_prompt]).to(device)
with torch.no_grad():
    text_features = clip_model.encode_text(text_token).float()

In [12]:
image_tensor = preprocess(input_image).unsqueeze(0).to(device)
with torch.no_grad():
    init_latents = vae.encode(image_tensor * 2 - 1).latent_dist.sample() * 0.18215

In [13]:
text_input = tokenizer(style_prompt, return_tensors="pt").input_ids.to(device)
with torch.no_grad():
    text_embeddings = text_encoder(text_input)[0]

In [14]:
latents = init_latents.clone().detach().requires_grad_(True)
optimizer = torch.optim.Adam([latents], lr=lr)

In [15]:
# prompt = "a photo of an astronaut riding a horse on mars"
# with autocast("cuda"):
#     image = pipe(prompt).images[0]  
    
# image.save("astronaut_rides_horse.png")

In [16]:
clip_model.eval()
vae.eval()

for step in trange(num_steps, desc="CLIP Optimization"):
    noise_pred = unet(latents, torch.tensor([scheduler.timesteps[0]], device=device), encoder_hidden_states=text_embeddings).sample
    latents_denoised = latents - noise_pred  # simplified step
        
    decoded = vae.decode(latents_denoised / 0.18215).sample
    decoded = (decoded.clamp(-1, 1) + 1) / 2
    clip_input = torch.nn.functional.interpolate(decoded, size=224, mode="bicubic", align_corners=False)
    clip_input = (clip_input - 0.5) / 0.5
    image_features = clip_model.encode_image(clip_input).float()

    loss = -torch.cosine_similarity(image_features, style_embedding).mean()
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 10 == 0:
        print(f"Step {step} - CLIP Loss: {loss.item():.4f}")


CLIP Optimization:   2%|██▍                                                                                                                        | 1/50 [00:28<23:23, 28.63s/it]

Step 0 - CLIP Loss: -0.1476


CLIP Optimization:  22%|██████████████████████████▊                                                                                               | 11/50 [06:59<25:30, 39.25s/it]

Step 10 - CLIP Loss: -0.2035


CLIP Optimization:  42%|███████████████████████████████████████████████████▏                                                                      | 21/50 [13:30<19:03, 39.44s/it]

Step 20 - CLIP Loss: -0.2142


CLIP Optimization:  62%|███████████████████████████████████████████████████████████████████████████▋                                              | 31/50 [20:04<12:33, 39.67s/it]

Step 30 - CLIP Loss: -0.2227


CLIP Optimization:  82%|████████████████████████████████████████████████████████████████████████████████████████████████████                      | 41/50 [26:39<05:55, 39.55s/it]

Step 40 - CLIP Loss: -0.2312


CLIP Optimization: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [32:33<00:00, 39.07s/it]


In [18]:
with torch.no_grad():
    final_image = vae.decode(latents / 0.18215).sample
    final_image = (final_image.clamp(-1, 1) + 1) / 2
    final_image = final_image.squeeze().permute(1, 2, 0).cpu().numpy()
    final_image = (final_image * 255).astype(np.uint8)
    # final_image = final_image.resize((orig_w, orig_h), Image.LANCZOS)
    Image.fromarray(final_image).save("clip_guided_stylized_output.png")

In [17]:
# torch.clear_autocast_cache()

In [4]:
torch.cuda.empty_cache()
torch.cuda.synchronize()