In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import os

GOOGLE_DRIVE_PATH_AFTER_MYDRIVE = "Raffaello_Sanzio/styleid"
GOOGLE_DRIVE_PATH = os.path.join('/content/drive', 'My Drive', GOOGLE_DRIVE_PATH_AFTER_MYDRIVE)
print(os.listdir(GOOGLE_DRIVE_PATH))

['ckpt', 'data', 'data_model', 'config', 'ldm', 'models', 'output', 'precomputed_feats', 'IMG_3724.JPG', 'IMG_3724_styled_joana-abreu-aFkzShngdaw-unsplash.png', 'lora', 'train', 'styleid.ipynb', 'train.ipynb']


In [4]:
import sys
sys.path.append(GOOGLE_DRIVE_PATH)

In [5]:
!pip install pytorch-lightning==1.4.2
!pip install omegaconf==2.1.1
!pip install torchmetrics==0.6.0
!pip install git+https://github.com/openai/CLIP.git
!pip install kornia==0.6

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-i6a3l9o5
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-i6a3l9o5
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [6]:
import os
import yaml
import argparse
import time
import copy
import torch
import numpy as np
import pickle
from contextlib import nullcontext
from PIL import Image
from torchvision import transforms
from einops import rearrange
from pytorch_lightning import seed_everything
from peft import PeftModel

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler

In [7]:
def load_model_from_config(model_config, ckpt_path, lora_path, verbose=False, load_lora=False):
    print(f"Loading model from {ckpt_path}")
    pl_sd = torch.load(ckpt_path, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(model_config["model"])
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    # load lora
    if load_lora:
        print(f"Loading LoRA from {lora_path}")
        model.model.diffusion_model = PeftModel.from_pretrained(model.model.diffusion_model, lora_path)
        # print(model.model.diffusion_model)
    model.cuda()
    model.eval()
    return model

In [8]:
def load_img(path):
    im = Image.open(path).convert("RGB")
    x, y = im.size
    print(f"Loaded input image of size ({x}, {y}) from {path}")
    h = w = 512
    im = transforms.CenterCrop(min(x,y))(im)
    im = im.resize((w,h), resample=Image.Resampling.LANCZOS)
    im = np.array(im).astype(np.float32) / 255.0
    im = im[None].transpose(0, 3, 1, 2)
    im = torch.from_numpy(im)
    return 2. * im - 1.

In [9]:
def adain(cnt_feat, sty_feat):
    cnt_mean = cnt_feat.mean(dim=[0,2,3], keepdim=True)
    cnt_std = cnt_feat.std(dim=[0,2,3], keepdim=True )
    sty_mean = sty_feat.mean(dim=[0,2,3], keepdim=True)
    sty_std = sty_feat.std(dim=[0,2,3], keepdim=True)
    return ((cnt_feat - cnt_mean) / cnt_std) * sty_std + sty_mean

In [10]:
def feat_merge(train_config, cnt_feats, sty_feats, start_step):
    feat_maps = [
        {
            "config" :{
                "gamma" : train_config["gamma"],
                "T" : train_config["T"],
            }
        } for _ in range(50)
    ]

    for i in range(len(feat_maps)):
        if i < (50 - start_step):
            continue
        cnt_feat = cnt_feats[i]
        sty_feat = sty_feats[i]
        ori_keys = sty_feat.keys()

        for ori_key in ori_keys:
            if ori_key[-1] == 'q':
                feat_maps[i][ori_key] = cnt_feat[ori_key]
            if ori_key[-1] == 'k' or ori_key[-1] == 'v':
                feat_maps[i][ori_key] = sty_feat[ori_key]
    return feat_maps


In [11]:
def train(args):
    with open(args.config_path) as f:
        config = yaml.safe_load(f)

    train_config = config["train_params"]

    seed = train_config["seed"]
    seed_everything(22)
    feat_path_root = train_config["precomputed"]

    output_path = os.path.join(GOOGLE_DRIVE_PATH, train_config["output_path"])
    os.makedirs(output_path, exist_ok=True)
    if len(feat_path_root):
        os.makedirs(os.path.join(GOOGLE_DRIVE_PATH, feat_path_root), exist_ok=True)

    with open(os.path.join(GOOGLE_DRIVE_PATH, train_config["model_config"])) as f:
        model_config = yaml.safe_load(f)

    ckpt_path = os.path.join(GOOGLE_DRIVE_PATH, train_config["ckpt"])
    lora_path = os.path.join(GOOGLE_DRIVE_PATH, "lora", train_config["lora_ckpt_name"])
    use_lora = False

    model = load_model_from_config(model_config, ckpt_path, lora_path, load_lora=use_lora)

    self_attn_output_block_indices = list(map(int, train_config["attn_layer"].split(',')))
    ddim_inversion_steps = train_config["ddim_inv_steps"]
    save_feature_timesteps = ddim_steps = train_config["save_feat_steps"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    unet_model = model.model.diffusion_model

    # scheduler
    sampler = DDIMSampler(model)
    sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=train_config["ddim_eta"], verbose=False)
    time_range = np.flip(sampler.ddim_timesteps)
    idx_time_dict = {}
    time_idx_dict = {}
    for i, t in enumerate(time_range):
        idx_time_dict[t] = i
        time_idx_dict[i] = t

    seed = torch.initial_seed()
    train_config["seed"] = seed
    print(f"Init Seed: {seed}")

    global feat_maps
    feat_maps = [
        {
            "config" :{
                "gamma" : train_config["gamma"],
                "T" : train_config["T"],
            }
        } for _ in range(50)
    ]

    def ddim_sampler_callback(pred_x0, xt, i):
        save_feature_maps_callback(i)
        save_feature_map(xt, 'z_enc', i)

    def save_feature_maps_callback(i):
        save_feature_maps(unet_model.output_blocks, i, "output_block")

    def save_feature_maps(blocks, i, feature_type="input_block"):
        block_idx = 0
        for block_idx, block in enumerate(blocks):
            if len(block) > 1 and "SpatialTransformer" in str(type(block[1])):
                if block_idx in self_attn_output_block_indices:
                    q = block[1].transformer_blocks[0].attn1.q
                    k = block[1].transformer_blocks[0].attn1.k
                    v = block[1].transformer_blocks[0].attn1.v
                    save_feature_map(q, f"{feature_type}_{block_idx}_self_attn_q", i)
                    save_feature_map(k, f"{feature_type}_{block_idx}_self_attn_k", i)
                    save_feature_map(v, f"{feature_type}_{block_idx}_self_attn_v", i)

    def save_feature_map(feature_map, filename, time):
        global feat_maps
        cur_idx = idx_time_dict[time]
        feat_maps[cur_idx][f"{filename}"] = feature_map

    start_step = train_config["start_step"]
    precision_scope = torch.autocast if train_config["precision"] == "autocast" else nullcontext

    c = model.get_learned_conditioning(["Raffaello Sanzio Painting"]) if use_lora else None
    # c = None
    uc = model.get_learned_conditioning([""])
    shape = [train_config['C'], train_config['H'] // train_config['f'], train_config['W'] // train_config['f']]
    sty_img_list = sorted(os.listdir(os.path.join(GOOGLE_DRIVE_PATH, train_config['sty'])))
    cnt_img_list = sorted(os.listdir(os.path.join(GOOGLE_DRIVE_PATH, train_config['cnt'])))

    begin = time.time()

    for sty_name in sty_img_list:
        sty_name_ = os.path.join(GOOGLE_DRIVE_PATH, train_config['sty'], sty_name)
        init_sty = load_img(sty_name_).to(device)
        seed = -1
        sty_title = os.path.basename(sty_name).split('.')[0]
        sty_feat_name = os.path.join(GOOGLE_DRIVE_PATH, feat_path_root,
                                     sty_title + "_sty.pkl")
        sty_z_enc = None

        if len(feat_path_root) > 0 and os.path.isfile(sty_feat_name):
            print("precomputed style Feature found and loading:", sty_feat_name)
            with open(sty_feat_name, 'rb') as h:
                sty_feat = pickle.load(f)
                sty_z_enc = torch.clone(sty_feat[0]['z_enc'])
        else:
            init_sty = model.get_first_stage_encoding(model.encode_first_stage(init_sty))
            sty_z_enc, _ = sampler.encode_ddim(
                init_sty.clone(),
                num_steps=ddim_inversion_steps,
                conditioning=c,
                unconditional_conditioning=uc,
                end_step=time_idx_dict[ddim_inversion_steps-1-start_step],
                callback_ddim_timesteps=save_feature_timesteps,
                img_callback=ddim_sampler_callback
            )
            sty_feat = copy.deepcopy(feat_maps)
            sty_z_enc = feat_maps[0]['z_enc']

        for cnt_name in cnt_img_list:
            cnt_name_ = os.path.join(GOOGLE_DRIVE_PATH, train_config['cnt'], cnt_name)
            init_cnt = load_img(cnt_name_).to(device)

            cnt_title = os.path.basename(cnt_name).split('.')[0]
            cnt_feat_name = os.path.join(GOOGLE_DRIVE_PATH, feat_path_root,
                                         cnt_title + "_cnt.pkl")
            cnt_z_enc = None

            if len(feat_path_root) > 0 and os.path.isfile(cnt_feat_name):
                print("Precomputed content feature loading: ", cnt_feat_name)
                with open(cnt_feat_name, 'rb') as h:
                    cnt_feat = pickle.load(h)
                    cnt_z_enc = torch.clone(cnt_feat[0]['z_enc'])
            else:
                init_cnt = model.get_first_stage_encoding(model.encode_first_stage(init_cnt))
                cnt_z_enc, _ = sampler.encode_ddim(
                    init_cnt.clone(),
                    num_steps=ddim_inversion_steps,
                    conditioning=c,
                    unconditional_conditioning=uc,
                    end_step=time_idx_dict[ddim_inversion_steps-1-start_step],
                    callback_ddim_timesteps=save_feature_timesteps,
                    img_callback=ddim_sampler_callback)
                cnt_feat = copy.deepcopy(feat_maps)
                cnt_z_enc = feat_maps[0]['z_enc']


            with torch.no_grad():
                with precision_scope("cuda"):
                    with model.ema_scope():
                        output_name = f"{cnt_title}_styled_{sty_title}.png"
                        print(f"Inversion end: {time.time() - begin}")
                        print(train_config["without_init_adain"], train_config["without_init_adain"])
                        if train_config["without_init_adain"]:
                            adain_z_enc = cnt_z_enc
                        else:
                            adain_z_enc = adain(cnt_z_enc, sty_z_enc)

                        feat_maps = feat_merge(train_config, cnt_feat, sty_feat, start_step=start_step)

                        if train_config["without_attn_injection"]:
                            feat_maps = None

                        samples_ddim, _ = sampler.sample(
                            S=ddim_steps,
                            batch_size=1,
                            shape=shape,
                            verbose=False,
                            conditioning=c,
                            unconditional_conditioning=uc,
                            eta=train_config["ddim_eta"],
                            x_T=adain_z_enc,
                            injected_features=feat_maps,
                            start_step=start_step
                        )

                        # x_samples_ddim = model.decode_first_stage(samples_ddim)
                        # x_samples_ddim = torch.clamp((x_samples_ddim + 1) / 2 , -1., 1.)
                        # x_samples_ddim = x_samples_ddim.cpu().permute(0,2,3,1).numpy()
                        # x_sample = 255 * x_samples_ddim[0]
                        # im = Image.fromarray(x_sample.astype(np.uint8))
                        # im.save(os.path.join(output_path, output_name))
                        x_samples_ddim = model.decode_first_stage(samples_ddim)
                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                        x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
                        x_image_torch = torch.from_numpy(x_samples_ddim).permute(0, 3, 1, 2)
                        x_sample = 255. * rearrange(x_image_torch[0].cpu().numpy(), 'c h w -> h w c')
                        img = Image.fromarray(x_sample.astype(np.uint8))

                        img.save(os.path.join(output_path, output_name))
parser = argparse.ArgumentParser()


In [12]:
args = argparse.Namespace(config_path=os.path.join(GOOGLE_DRIVE_PATH, 'config/style.yaml'))
train(args)

INFO:pytorch_lightning.utilities.seed:Global seed set to 22
  pl_sd = torch.load(ckpt_path, map_location="cpu")


Loading model from /content/drive/My Drive/Raffaello_Sanzio/styleid/models/ldm/stable-diffusion-v1/model.ckpt
Global Step: 470000
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Init Seed: 22
Loaded input image of size (1920, 1038) from /content/drive/My Drive/Raffaello_Sanzio/styleid/data/sty/howl003.jpg
Running DDIM inversion with 50 timesteps


DDIM Inversion:   0%|          | 0/50 [00:00<?, ?it/s]

Selected timesteps for ddim sampler: [  1  21  41  61  81 101 121 141 161 181 201 221 241 261 281 301 321 341
 361 381 401 421 441 461 481 501 521 541 561 581 601 621 641 661 681 701
 721 741 761 781 801 821 841 861 881 901 921 941 961 981]


DDIM Inversion: 100%|██████████| 50/50 [00:08<00:00,  5.78it/s]


Loaded input image of size (1200, 630) from /content/drive/My Drive/Raffaello_Sanzio/styleid/data/cnt/one.jpg
Running DDIM inversion with 50 timesteps


DDIM Inversion:   0%|          | 0/50 [00:00<?, ?it/s]

Selected timesteps for ddim sampler: [  1  21  41  61  81 101 121 141 161 181 201 221 241 261 281 301 321 341
 361 381 401 421 441 461 481 501 521 541 561 581 601 621 641 661 681 701
 721 741 761 781 801 821 841 861 881 901 921 941 961 981]


DDIM Inversion: 100%|██████████| 50/50 [00:08<00:00,  5.81it/s]


Inversion end: 17.834339141845703
False False
Data shape for DDIM sampling is (1, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:05<00:00,  8.84it/s]


Loaded input image of size (1920, 1038) from /content/drive/My Drive/Raffaello_Sanzio/styleid/data/sty/majo050.jpg
Running DDIM inversion with 50 timesteps


DDIM Inversion:   0%|          | 0/50 [00:00<?, ?it/s]

Selected timesteps for ddim sampler: [  1  21  41  61  81 101 121 141 161 181 201 221 241 261 281 301 321 341
 361 381 401 421 441 461 481 501 521 541 561 581 601 621 641 661 681 701
 721 741 761 781 801 821 841 861 881 901 921 941 961 981]


DDIM Inversion: 100%|██████████| 50/50 [00:08<00:00,  5.74it/s]


Loaded input image of size (1200, 630) from /content/drive/My Drive/Raffaello_Sanzio/styleid/data/cnt/one.jpg
Running DDIM inversion with 50 timesteps


DDIM Inversion:   0%|          | 0/50 [00:00<?, ?it/s]

Selected timesteps for ddim sampler: [  1  21  41  61  81 101 121 141 161 181 201 221 241 261 281 301 321 341
 361 381 401 421 441 461 481 501 521 541 561 581 601 621 641 661 681 701
 721 741 761 781 801 821 841 861 881 901 921 941 961 981]


DDIM Inversion: 100%|██████████| 50/50 [00:08<00:00,  5.72it/s]


Inversion end: 41.708003520965576
False False
Data shape for DDIM sampling is (1, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:05<00:00,  9.38it/s]


Loaded input image of size (1920, 1038) from /content/drive/My Drive/Raffaello_Sanzio/styleid/data/sty/ponyo033.jpg
Running DDIM inversion with 50 timesteps


DDIM Inversion:   0%|          | 0/50 [00:00<?, ?it/s]

Selected timesteps for ddim sampler: [  1  21  41  61  81 101 121 141 161 181 201 221 241 261 281 301 321 341
 361 381 401 421 441 461 481 501 521 541 561 581 601 621 641 661 681 701
 721 741 761 781 801 821 841 861 881 901 921 941 961 981]


DDIM Inversion: 100%|██████████| 50/50 [00:08<00:00,  5.67it/s]


Loaded input image of size (1200, 630) from /content/drive/My Drive/Raffaello_Sanzio/styleid/data/cnt/one.jpg
Running DDIM inversion with 50 timesteps


DDIM Inversion:   0%|          | 0/50 [00:00<?, ?it/s]

Selected timesteps for ddim sampler: [  1  21  41  61  81 101 121 141 161 181 201 221 241 261 281 301 321 341
 361 381 401 421 441 461 481 501 521 541 561 581 601 621 641 661 681 701
 721 741 761 781 801 821 841 861 881 901 921 941 961 981]


DDIM Inversion: 100%|██████████| 50/50 [00:08<00:00,  5.67it/s]


Inversion end: 65.3847188949585
False False
Data shape for DDIM sampling is (1, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:05<00:00,  9.38it/s]


In [13]:
"""
    testing cases
        1. With different text emb during training
        2. with out conditioning
        3.
"""

'\n    testing cases\n        1. With different text emb during training\n        2. with out conditioning\n        3.\n'