In [None]:
import urllib.request
import os

# Download the pretrained VQGAN (ImageNet, f16, 16384 codes) checkpoint and
# config into the taming-transformers directory layout, skipping files that
# are already present.
checkpoint_path = os.path.join(
    "taming-transformers", "models", "vqgan_imagenet_f16_16384", "checkpoints"
)
config_path = os.path.join(
    "taming-transformers", "models", "vqgan_imagenet_f16_16384", "configs"
)

os.makedirs(checkpoint_path, exist_ok=True)
os.makedirs(config_path, exist_ok=True)

checkpoint_path = os.path.join(checkpoint_path, "last.ckpt")
config_path = os.path.join(config_path, "model.yaml")

checkpoint_url = "https://heibox.uni-heidelberg.de/f/867b05fc8c4841768640/?dl=1"
config_url = "https://heibox.uni-heidelberg.de/f/274fb24ed38341bfa753/?dl=1"


def _download_if_missing(url, dest, exists_message):
    """Download ``url`` to ``dest`` unless ``dest`` already exists.

    The data is written to ``dest + ".part"`` first and only renamed to
    ``dest`` once ``urlretrieve`` has returned. An interrupted download
    therefore never leaves a truncated file behind that the existence check
    would accept on the next run.

    Args:
        url: source URL.
        dest: final file path.
        exists_message: text printed when ``dest`` is already present.
    """
    if os.path.exists(dest):
        print(exists_message)
        return
    tmp_path = dest + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, dest)  # atomic rename on the same filesystem
    finally:
        # Only still present if the download or rename failed: drop the partial file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


_download_if_missing(checkpoint_url, checkpoint_path, "Last.ckpt already exists")
_download_if_missing(config_url, config_path, "model.yaml already exists")

In [None]:
# ## install some extra libraries
# !pip install --no-deps ftfy regex tqdm
# !pip install omegaconf==2.0.0 pytorch-lightning==1.0.8
# !pip uninstall torchtext --yes
# !pip install einops
# !pip install openai-clip
# !pip install taming-transformers

In [None]:
import numpy as np
import torch, imageio, pdb, math
import torchvision.transforms 
import torchvision.transforms.functional 

import PIL
import matplotlib.pyplot as plt 

import yaml
from omegaconf import OmegaConf

import clip

In [None]:
# helper functions

def show_from_tensor(tensor):
    """Display a 3-D (C, H, W) tensor as an image in a 10x7-inch figure.

    Values are expected in [0, 1] (e.g. the output of ``norm_data``). They are
    scaled to 0-255 bytes and transposed to (H, W, C) for ``imshow``. The
    caller's tensor is not modified.
    """
    as_bytes = tensor.mul(255).byte()
    hwc_pixels = as_bytes.cpu().numpy().transpose((1, 2, 0))

    plt.figure(figsize=(10, 7))
    plt.axis("off")
    plt.imshow(hwc_pixels)
    plt.show()

def save_tensor(tensor, path):
    """Render a 3-D (C, H, W) tensor with matplotlib and save the figure to ``path``.

    Values are expected in [0, 1] (e.g. the output of ``norm_data``). They are
    scaled to 0-255 bytes and transposed to (H, W, C) for ``imshow``. The file
    written is the whole 10x7-inch figure with axes hidden, not the raw pixels.
    """
    pixels = tensor.mul(255).byte()  # mul() returns a new tensor; the input is untouched
    img = pixels.cpu().numpy().transpose((1, 2, 0))

    fig = plt.figure(figsize=(10, 7))
    try:
        plt.axis("off")
        plt.imshow(img)
        fig.savefig(path)
    finally:
        # Close the figure so repeated calls don't accumulate open figures
        # (memory growth, matplotlib's "too many open figures" warning, and
        # leftover figures being rendered at the end of the cell).
        plt.close(fig)
    
def norm_data(data):
    """Map values from [-1, 1] to [0, 1]; anything outside [-1, 1] is clipped first."""
    clipped = data.clip(-1, 1)
    return (clipped + 1) / 2


# Parameters
# Optimizer settings (used by init_params for AdamW)
learning_rate = 0.5
weight_decay = 0.1

batch_size = 1        # leading dimension of the optimised latent
noise_factor = 0.22   # scale of the Gaussian noise added to CLIP crops

# Weights applied to the text encodings in optimize_result
w1 = 1
w2 = 1

total_iteration = 100  # optimisation steps per prompt

img_shape = [225, 400, 3]  # height, width, channel
size1, size2, channels = img_shape

In [None]:
# CLIP MODEL
# Load ViT-B/32 without TorchScript (jit=False) so it behaves like a regular
# nn.Module, and switch it to eval mode. The preprocessing transform that
# clip.load also returns is not needed here.
clipmodel = clip.load("ViT-B/32", jit=False)[0]
clipmodel.eval()
print(clip.available_models())
print("Clip model visual input resolution:", clipmodel.visual.input_resolution)

device = torch.device("cuda")
torch.cuda.empty_cache()

In [None]:
from taming.models.vqgan import VQModel

def load_config(config_path, display=False):
    """Load an OmegaConf config from ``config_path``.

    When ``display`` is true, the config is also printed as YAML.
    """
    cfg = OmegaConf.load(config_path)
    if not display:
        return cfg
    print(yaml.dump(OmegaConf.to_container(cfg)))
    return cfg

def load_vqgan(config, chk_path=None):
    """Build a taming-transformers ``VQModel`` and optionally load checkpoint weights.

    Args:
        config: OmegaConf config whose ``model.params`` are passed to ``VQModel``.
        chk_path: optional path to a checkpoint whose ``"state_dict"`` entry
            holds the weights (loaded with ``map_location="cpu"``).

    Returns:
        The model in eval mode; the caller moves it to a device.
    """
    model = VQModel(**config.model.params)
    if chk_path is not None:
        # NOTE(security): torch.load unpickles the file -- only use trusted checkpoints.
        state_dict = torch.load(chk_path, map_location="cpu")["state_dict"]
        # strict=False tolerates keys the model doesn't have (presumably
        # training-only weights in the checkpoint -- TODO confirm). Missing
        # keys, however, mean part of the model stays randomly initialised.
        missing, _unexpected = model.load_state_dict(state_dict, strict=False)
        if missing:
            print(
                f"load_vqgan: {len(missing)} model weights missing from checkpoint, "
                f"e.g. {list(missing)[:5]}"
            )
    return model.eval()

def generator(x):
    """Decode latent ``x`` into an image tensor using the global ``taming_model``.

    Applies ``post_quant_conv`` followed by ``decoder``; no codebook
    quantization is performed here.
    """
    return taming_model.decoder(taming_model.post_quant_conv(x))

# Build the VQGAN from the downloaded config/checkpoint and move it to `device`.
taming_config = load_config(config_path=config_path)
taming_model = load_vqgan(taming_config, checkpoint_path)
taming_model = taming_model.to(device)

In [None]:
### Declare the values that we are going to optimize
class Parameters(torch.nn.Module):
    """Holds the latent being optimised as a single trainable CUDA parameter.

    The latent has shape (batch_size, 256, size1 // 16, size2 // 16) and is
    initialised as sin(0.5 * N(0, 1)), so every entry lies in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        latent_shape = (batch_size, 256, size1 // 16, size2 // 16)
        noise = torch.randn(*latent_shape).cuda()
        self.data = torch.nn.Parameter(torch.sin(0.5 * noise))

    def forward(self):
        """Return the latent parameter itself (the live tensor, not a copy)."""
        return self.data

def init_params():
    """Create a fresh latent on CUDA and an AdamW optimizer for it.

    Returns:
        (params, optimizer): the ``Parameters`` module and an AdamW optimizer
        over its latent, using the global ``learning_rate`` and ``weight_decay``.
    """
    params = Parameters().cuda()
    param_groups = [{'params': [params.data], 'lr': learning_rate}]
    optimizer = torch.optim.AdamW(param_groups, weight_decay=weight_decay)
    return params, optimizer
        

In [None]:
## CLIP text encodings, image normalisation, augmentations and initial-latent preview
# Per-channel RGB mean/std applied to image crops before clipmodel.encode_image.
normalize = torchvision.transforms.Normalize(
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
)

def encodeText(text):
    """Encode ``text`` with the global CLIP model and return the text embedding.

    ``text`` is tokenized with ``clip.tokenize`` and the tokens are moved to
    CUDA. For a single string the result is presumably shape (1, 512) for
    ViT-B/32 -- TODO confirm.

    The encoder runs under ``torch.no_grad()``, so no autograd graph is built
    through the text encoder and the returned tensor never requires grad.
    """
    tokens = clip.tokenize(text).cuda()
    with torch.no_grad():
        embedding = clipmodel.encode_text(tokens)
    return embedding

def createEncodings(include, exclude, extras):
    """Encode the prompt texts used by the optimisation.

    Args:
        include: iterable of prompt strings, each encoded separately.
        exclude: text whose concepts the image should avoid; "" for none.
        extras: text combined with every prompt; "" for none.

    Returns:
        (include_enc, exclude_enc, extras_enc): a list of prompt encodings,
        plus the exclude/extras encodings. An empty exclude/extras string
        yields the integer 0 instead of an encoding.
    """
    include_enc = [encodeText(prompt_text) for prompt_text in include]
    exclude_enc = 0 if exclude == "" else encodeText(exclude)
    extras_enc = 0 if extras == "" else encodeText(extras)
    return include_enc, exclude_enc, extras_enc

# Random augmentations applied to the padded image before cropping:
# a random horizontal flip, then a random affine (degrees=30,
# translate=(0.2, 0.2)) that fills exposed pixels with 0.
augTransform = torch.nn.Sequential(
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomAffine(degrees=30, translate=(0.2, 0.2), fill=0),
).cuda()

# Create the latent and its optimizer.
Params, optimizer = init_params()

# Sanity check: decode the freshly initialised latent and display it.
with torch.no_grad():
    print(Params().shape)
    img = generator(Params()).cpu()
    img = norm_data(img)
    print("img dimension: ", img.shape)
    show_from_tensor(img[0])

In [None]:
## Create crops
def create_crops(img, num_crops=30, num_rands=0):
    """Pad and augment ``img``, then cut ``num_crops`` random square crops resized to 224x224.

    Args:
        img: image batch of shape (N, C, H, W); N is presumably ``batch_size``
            (1 here) -- TODO confirm.
        num_crops: number of crops to cut.
        num_rands: extra U(0, 1) factors multiplied into each crop's noise
            scale. 0 keeps a single U(0, 1) factor, matching the previous
            behaviour.

    Returns:
        Tensor of shape (N * num_crops, C, 224, 224). Each crop gets Gaussian
        noise scaled by ``noise_factor`` times its own random factor.
    """
    pad = size1 // 2
    # Zero-pad every side so crops can extend past the image border
    # (e.g. 1x3x224x400 -> 1x3x448x624).
    img = torch.nn.functional.pad(img, (pad, pad, pad, pad), mode='constant', value=0)
    img = augTransform(img)  # random horizontal flip + random affine

    height, width = img.shape[-2:]
    crop_set = []
    for _ in range(num_crops):
        # Side length ~ N(1, 0.5) * size1, clamped to [0.2, 1.5] * size1 and to the padded image.
        gap = int(torch.normal(1.0, 0.5, ()).clip(0.2, 1.5) * size1)
        gap = min(gap, height, width)
        # Bound offsets by the *padded* height/width. The previous bound of
        # 2*size1 for both axes overshot the height (non-square crops) and
        # never reached the right-hand part of the padded width.
        offset_y = int(torch.randint(0, height - gap + 1, ()))
        offset_x = int(torch.randint(0, width - gap + 1, ()))

        crop = img[:, :, offset_y:offset_y + gap, offset_x:offset_x + gap]
        crop = torch.nn.functional.interpolate(crop, (224, 224), mode="bilinear", align_corners=True)
        crop_set.append(crop)
    img_crops = torch.cat(crop_set, 0)

    # Per-crop noise strength: product of (1 + num_rands) uniform factors.
    randnormal = torch.randn_like(img_crops)
    scale_shape = (img_crops.shape[0], 1, 1, 1)
    randstotal = torch.rand(scale_shape, device=img_crops.device)
    for _ in range(num_rands):
        randstotal *= torch.rand(scale_shape, device=img_crops.device)

    return img_crops + noise_factor * randstotal * randnormal

In [None]:
def showme(Params, show_crop):
    """Decode the current latent and display it, optionally with one augmented crop.

    Runs without gradients. Returns the first image of the decoded batch,
    on the CPU and mapped to [0, 1] by ``norm_data``.
    """
    with torch.no_grad():
        generated = generator(Params())

        if show_crop:
            print("Augmented cropped example")
            single_crop = create_crops(generated.float(), num_crops=1)[0]
            show_from_tensor(norm_data(single_crop))

        print("Generation")
        latest_gen = norm_data(generated.cpu())
        show_from_tensor(latest_gen[0])

    return latest_gen[0]

In [None]:
#Optimization Process
def optimize_result(Params, prompt, extras_enc, exclude_enc):
    """Compute the per-crop CLIP loss for the current latent.

    Args:
        Params: module whose call returns the latent to decode.
        prompt: CLIP text embedding of the main prompt (e.g. 1 x 512).
        extras_enc: embedding added to the prompt, weighted by ``w2``; or 0.
        exclude_enc: embedding the image should move away from; or 0 for none
            (the sentinel ``createEncodings`` returns for an empty string).

    Returns:
        Tensor with one value per crop:
        ``-alpha * cos(include, image) + beta * cos(exclude, image)``.
    """
    alpha, beta = 1, 0.5  # alpha: weight of the include term, beta: weight of the exclude penalty

    # Image side: decode -> [0, 1] -> augmented crops -> CLIP normalisation -> encode.
    out = generator(Params())
    out = norm_data(out)
    out = create_crops(out)
    out = normalize(out)
    image_enc = clipmodel.encode_image(out)

    # Text side: w1 weights the prompt and w2 the extras; the sum is L2-normalised.
    final_enc = w1 * prompt + w2 * extras_enc  # 1 x 512
    final_text_include_enc = final_enc / final_enc.norm(dim=1, keepdim=True)

    main_loss = torch.cosine_similarity(final_text_include_enc, image_enc, -1)  # one per crop
    if torch.is_tensor(exclude_enc):
        penalize_loss = torch.cosine_similarity(exclude_enc, image_enc, -1)
    else:
        # No exclude text: cosine_similarity needs a tensor, and the penalty is just 0.
        penalize_loss = 0

    return -alpha * main_loss + beta * penalize_loss
    
def optimize(Params, optimizer, prompt, extras_enc, exclude_enc):
    """Run a single optimizer step on the mean CLIP loss and return that loss tensor.

    Note the argument order: ``extras_enc`` comes before ``exclude_enc``.
    """
    optimizer.zero_grad()
    mean_loss = optimize_result(Params, prompt, extras_enc, exclude_enc).mean()
    mean_loss.backward()
    optimizer.step()
    return mean_loss

In [None]:
def training_loop(Params, optimizer, include_enc, exclude_enc, extras_enc, show_step=(total_iteration-1), show_crop=False):
    """Optimise the latent for each prompt in turn and collect snapshots.

    The same latent and optimizer carry over from one prompt to the next, so
    each prompt starts from the previous prompt's result.

    Args:
        Params, optimizer: latent module and its optimizer (see ``init_params``).
        include_enc: list of prompt encodings, optimised sequentially.
        exclude_enc: encoding to steer away from, or 0.
        extras_enc: encoding added to every prompt, or 0.
        show_step: once past ``total_iteration // 2``, snapshot every
            iteration divisible by ``show_step``. The default gives one
            snapshot per prompt, at the last iteration.
        show_crop: also display an augmented crop at each snapshot.

    Returns:
        (res_img, res_z): decoded images and frozen copies of the latent,
        one pair per snapshot.
    """
    res_img = []
    res_z = []

    for prompt in include_enc:
        # To restart from a fresh latent for every prompt instead:
        # Params, optimizer = init_params()  # 1 x 256 x 14 x 25

        for iteration in range(total_iteration):
            # Keywords, because optimize() takes (extras_enc, exclude_enc) --
            # the reverse of this function's order. Passing them positionally
            # attracted the image to the exclude text and repelled the extras.
            loss = optimize(Params, optimizer, prompt,
                            extras_enc=extras_enc, exclude_enc=exclude_enc)

            if iteration >= (total_iteration // 2) and iteration % show_step == 0:
                res_img.append(showme(Params, show_crop))
                # Params() returns the live Parameter, which the optimizer keeps
                # updating in place; store a detached copy so the snapshot is frozen.
                res_z.append(Params().detach().clone())
                print("loss:", loss.item(), "\niteration:", iteration)

        torch.cuda.empty_cache()
    return res_img, res_z

In [None]:
torch.cuda.empty_cache()

# Weights applied to the text encodings in optimize_result
w1 = 1
w2 = 1

include = ["sketch of a lady", "sketch of a man on a horse"]
exclude = ' watermark, cropped, confusing, incoherent, cut, blurry'
extras = "watercolor paper texture"

include_enc, exclude_enc, extras_enc = createEncodings(include, exclude, extras)
res_img, res_z = training_loop(
    Params, optimizer, include_enc, exclude_enc, extras_enc, show_crop=True
)


In [None]:
# Pair prompts with snapshots in order and save each as "<prompt>.png"
# (the default show_step produces one snapshot per prompt).
for prompt, img in zip(include, res_img):
    save_tensor(img, f"{prompt}.png")

In [None]:
# Sanity check: snapshot counts, shapes of the first snapshot, value range of the first latent.
print(*map(len, (res_img, res_z)))
print(*(snapshot.shape for snapshot in (res_img[0], res_z[0])))
print(*(reduce_fn() for reduce_fn in (res_z[0].max, res_z[0].min)))

In [None]:
def interpolate(res_z_list, duration_list, fps=25):
    """Render frames that morph from each latent to the next, wrapping back to the first.

    Args:
        res_z_list: latents to morph between (e.g. the snapshots from
            ``training_loop``).
        duration_list: seconds per transition, paired element-wise with
            ``res_z_list``; surplus entries in the longer list are ignored.
        fps: frames per second used to turn durations into frame counts.

    Returns:
        List of PIL RGB images, one per frame (empty if either list is empty).
    """
    gen_img_list = []

    # zip() pairs each latent with its duration. The old
    # ``enumerate(res_z_list, duration_list)`` passed the list as enumerate's
    # integer ``start`` argument and raised TypeError.
    with torch.no_grad():  # frames are only rendered, never back-propagated
        for idx, (z1, duration) in enumerate(zip(res_z_list, duration_list)):
            num_steps = int(duration * fps)
            z2 = res_z_list[(idx + 1) % len(res_z_list)]

            for step in range(num_steps):
                # Ease-in weight rising from 0 towards sin(1.5)**6 (about 0.985).
                alpha = math.sin(1.5 * step / num_steps) ** 6
                z_new = alpha * z2 + (1 - alpha) * z1

                new_gen = norm_data(generator(z_new).cpu())[0]
                new_img = torchvision.transforms.ToPILImage(mode='RGB')(new_gen)
                gen_img_list.append(new_img)

    return gen_img_list

# Seconds allotted to each transition between consecutive latents.
durations = [3] * 6
interpolate_results_images = interpolate(res_z, durations)

In [None]:
out_video_path = "res1.mp4"
# The context manager finalises and closes the file even if appending a frame
# raises, so a failure doesn't leave a half-written, still-open video behind.
with imageio.get_writer(out_video_path, fps=25) as writer:
    for pil_img in interpolate_results_images:
        img = np.array(pil_img, dtype=np.uint8)  # PIL RGB -> H x W x 3 uint8 array
        writer.append_data(img)

In [None]:
from IPython.display import HTML
from base64 import b64encode
# Embed the rendered video inline as a base64 data URI.
# `with` closes the file handle, which the bare open(...).read() leaked.
with open("res1.mp4", "rb") as video_file:
    mp4 = video_file.read()
data = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""<video width=800 controls><source src="%s" type="video/mp4"></video>""" % data)