# Welcome to SINE: SINgle Image Editing with Text-to-Image Diffusion Models!

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
    print('Not connected to a GPU')
else:
    print(gpu_info)

from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
    print('Not using a high-RAM runtime')
else:
    print('You are using a high-RAM runtime!')

# Step 1: Set up the required libraries and models.
This may take a few minutes.

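You may optionally download via pydrive (`download_with_pydrive = True`, the default) to authenticate with Google Drive and avoid its download limits when fetching the fine-tuned models; set it to `False` to download them with `gdown` instead.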

In [None]:
#@title Setup

import os

from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

from argparse import Namespace

import sys
import numpy as np

from PIL import Image

import torch
import torchvision.transforms as transforms

# Target device name for this cell. NOTE(review): Step 3 rebinds `device` to a
# torch.device with a CPU fallback.
device = 'cuda'


# install requirements
# Clone the SINE repository. NOTE: re-running this cell makes `git clone` fail
# because sine_dir already exists.
!git clone https://github.com/zhang-zx/SINE.git sine_dir

# Work from inside the clone: later cells use relative paths such as
# configs/stable-diffusion/... and models/..., which resolve against this directory.
%cd sine_dir/
# NOTE(review): torchtext is presumably removed to avoid a conflict with the pinned torch below — confirm.
!pip uninstall -y torchtext
# Pinned library versions this notebook was written against.
# NOTE(review): confirm these pins (e.g. torch==1.10.2) still have wheels for the current Colab Python.
! pip install transformers==4.18.0 einops==0.4.1 omegaconf==2.1.1 torchmetrics==0.6.0 torch-fidelity==0.3.0 kornia==0.6 albumentations==1.1.0 opencv-python==4.2.0.34 imageio==2.14.1 setuptools==59.5.0 pillow==9.0.1 
! pip install torch==1.10.2 torchvision==0.11.3
! pip install pytorch-lightning==1.5.9
! pip install git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
! pip install git+https://github.com/openai/CLIP.git@main#egg=clip
# Install the cloned SINE repository itself in editable mode.
! pip install -e .



# True: authenticate once and fetch fine-tuned checkpoints through the Drive API (pydrive).
# False: fall back to the gdown command-line tool.
download_with_pydrive = True 
    
class Downloader(object):
    def __init__(self, use_pydrive):
        self.use_pydrive = use_pydrive

        if self.use_pydrive:
            self.authenticate()
        
    def authenticate(self):
        auth.authenticate_user()
        gauth = GoogleAuth()
        gauth.credentials = GoogleCredentials.get_application_default()
        self.drive = GoogleDrive(gauth)
    
    def download_file(self, file_id, file_dst):
        if self.use_pydrive:
            downloaded = self.drive.CreateFile({'id':file_id})
            downloaded.FetchMetadata(fetch_all=True)
            downloaded.GetContentFile(file_dst)
        else:
            !gdown --id $file_id -O $file_dst

downloader = Downloader(download_with_pydrive)

pre_trained_path = os.path.join('models', 'ldm', 'stable-diffusion-v4')
os.makedirs(pre_trained_path, exist_ok=True)
!wget -O models/ldm/stable-diffusion-v4/sd-v1-4-full-ema.ckpt https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4-full-ema.ckpt 



# Step 2: Download the selected fine-tuned model. 

In [None]:
finetuned_models_dir = os.path.join('./models', 'finetuned')
os.makedirs(finetuned_models_dir, exist_ok=True)

orig_image_dir = './dataset'
os.makedirs(orig_image_dir, exist_ok=True)

# Colab form field: which fine-tuned SINE checkpoint to use. The four tables below
# are keyed by exactly these option labels and must be kept in sync with them.
source_model_type = 'dog w/o patch-based fine-tuning' #@param['dog w/o patch-based fine-tuning', 'dog w/ patch-based fine-tuning', 'Girl with a pearl earring', 'Monalisa', 'castle w/o patch-based fine-tuning', 'castle w/ patch-based fine-tuning']

# Google Drive file id of each fine-tuned checkpoint.
source_model_download_path = {"dog w/o patch-based fine-tuning":   "1jHgkyxrwUXyMR2zBK9WAEWioEP3-F3fd",
                              "dog w/ patch-based fine-tuning":    "1YI7c29qBIy83OqJ4ykoAAul6P8uaXmls",
                              "Girl with a pearl earring":    "1l6GCEfyURKQiCF77ZriYoZRtkOXQCyWD",
                              "Monalisa": "194CDgHkomKrLvgFj89kamoTjUbMwyUIC",
                              "castle w/o patch-based fine-tuning":    "19I8ftab9vMQWnqPH2O7aHe-GnYolmVFF",
                              "castle w/ patch-based fine-tuning":  "1srzUr1fg6jTFKuf0M5oi5JgBsVhCt-nb"}

# Local file name (inside finetuned_models_dir) each checkpoint is saved under.
model_names = { "dog w/o patch-based fine-tuning":   "dog_wo_patch.ckpt",
                "dog w/ patch-based fine-tuning":    "dog_w_patch.ckpt",
                "Girl with a pearl earring":    "girl.ckpt",
                "Monalisa": "monalisa.ckpt",
                "castle w/o patch-based fine-tuning":    "castle_wo_patch.ckpt",
                "castle w/ patch-based fine-tuning":  "castle_w_patch.ckpt"}

# Model config matching how each checkpoint was fine-tuned (plain vs. patch-based).
model_configs = { "dog w/o patch-based fine-tuning":   "./configs/stable-diffusion/v1-inference.yaml",
                "dog w/ patch-based fine-tuning":    "./configs/stable-diffusion/v1-inference_patch.yaml",
                "Girl with a pearl earring":    "./configs/stable-diffusion/v1-inference_patch_nearest.yaml",
                "Monalisa": "./configs/stable-diffusion/v1-inference_patch_nearest.yaml",
                "castle w/o patch-based fine-tuning":    "./configs/stable-diffusion/v1-inference.yaml",
                "castle w/ patch-based fine-tuning":  "./configs/stable-diffusion/v1-inference_patch.yaml"}

# Prompt (with the "sks" identifier token) used to condition the fine-tuned model.
orig_prompts = { "dog w/o patch-based fine-tuning":   "picture of a sks dog",
                "dog w/ patch-based fine-tuning":    "picture of a sks dog",
                "Girl with a pearl earring":    "painting of a sks girl",
                "Monalisa": "painting of a sks lady",
                "castle w/o patch-based fine-tuning":    "picture of a sks castle",
                "castle w/ patch-based fine-tuning":  "picture of a sks castle"}

download_string = source_model_download_path[source_model_type]
file_name = model_names[source_model_type]

config_name = model_configs[source_model_type]
fine_tune_prompt = orig_prompts[source_model_type]

# Only download when the checkpoint is not already on disk (e.g. on a cell re-run).
if not os.path.isfile(os.path.join(finetuned_models_dir, file_name)):
    downloader.download_file(download_string, os.path.join(finetuned_models_dir, file_name))

# Step 3: Edit the image with model-based guidance

In [None]:
import argparse, os, sys, glob
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid, save_image
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from IPython.display import display


def chunk(it, size):
    """Split an iterable into consecutive tuples of up to ``size`` items.

    The final tuple may be shorter. Yields nothing for an empty iterable or when
    ``size`` is 0. ``iter(it)`` is called immediately, so a non-iterable argument
    fails at call time rather than on first use.
    """
    source = iter(it)

    def _batches():
        # Keep pulling fixed-size slices until the source runs dry.
        while True:
            batch = tuple(islice(source, size))
            if not batch:
                return
            yield batch

    return _batches()


def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by ``config.model`` and load weights from ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: path to a checkpoint readable by ``torch.load``. Weights are taken
            from its ``"state_dict"`` entry when present; otherwise the loaded
            object itself is treated as the state dict.
        verbose: when True, print the missing/unexpected keys reported by the
            non-strict load.

    Returns:
        The model in eval mode, moved to CUDA when a GPU is available.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints from sources you trust.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    # Lightning-style checkpoints wrap the weights; accept a bare state dict too.
    sd = pl_sd.get("state_dict", pl_sd)
    model = instantiate_from_config(config.model)
    # strict=False: key mismatches are tolerated and only reported when verbose.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and missing:
        print("missing keys:")
        print(missing)
    if verbose and unexpected:
        print("unexpected keys:")
        print(unexpected)

    # Only move to the GPU when one exists, so loading does not crash on a CPU runtime.
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model

seed = 42

# Base Stable Diffusion v1.4 model, used for the text-prompt (editing) guidance.
config = OmegaConf.load('configs/stable-diffusion/v1-inference.yaml')
model = load_model_from_config(config, 'models/ldm/stable-diffusion-v4/sd-v1-4-full-ema.ckpt')

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

# Fine-tuned single-image model. It must be built from its OWN config: the
# patch-based variants select a different config file (see Step 2), and building
# them from the base config would instantiate the wrong architecture.
sin_config = OmegaConf.load(config_name)
sin_model = load_model_from_config(sin_config, os.path.join(finetuned_models_dir, file_name))
sin_model = sin_model.to(device)


In [None]:
# Colab form fields (keep the `name = value #@param ...` shape so the form still renders).
v = 0.7 #@param {type:"slider", min:0, max:1, step:0.05}
K_min = 400 #@param {type:"slider", min:0, max:1000, step:10}
scale = 7.5 #@param {type:"slider", min:1.0, max:50, step:0.5}
ddim_steps = 100
ddim_eta = 0.0
H = W = 512

prompt = "a dog wearing a superhero cape" #@param {'type': 'string'}

# `v` and `1 - v` are handed to the sampler as the two guidance weights
# (cond_beta / cond_beta_sin); range_t_min/range_t_max bound the timestep window
# (K_min up to 1000). See DDIMSinSampler for exactly how they are applied.
extra_config = dict(
    cond_beta=v,
    cond_beta_sin=1.0 - v,
    range_t_max=1000,
    range_t_min=K_min,
)

from ldm.models.diffusion.guidance_ddim import DDIMSinSampler

# The sampler reads its guidance settings from an `extra_config` attribute on its model.
sampler = DDIMSinSampler(model, sin_model)
sampler.model.extra_config = extra_config


batch_size = 1
n_rows = 2
start_code = None
precision_scope = autocast
num_samples = 4

# Seed the Python/NumPy/torch RNGs with `seed` (set when the models were loaded) so
# re-running this cell with the same settings reproduces the same images.
seed_everything(seed)

all_samples = list()

with torch.no_grad(), precision_scope("cuda"), model.ema_scope(), sin_model.ema_scope():
    # The text conditioning is identical for every sample, so encode it once.
    # Both unconditional embeddings default to None so they are always bound,
    # including when classifier-free guidance is disabled (scale == 1.0).
    uc = None
    uc_sin = None
    if scale != 1.0:
        uc = model.get_learned_conditioning(batch_size * [""])
        uc_sin = sin_model.get_learned_conditioning(batch_size * [""])

    prompts = [prompt] * batch_size
    prompts_single = [fine_tune_prompt] * batch_size
    c = model.get_learned_conditioning(prompts)
    c_sin = sin_model.get_learned_conditioning(prompts_single)

    # Sampling shape: 4 channels at 1/8 of the output resolution.
    shape = [4, H // 8, W // 8]

    for _ in trange(num_samples, desc="Sampling"):
        samples_ddim, _intermediates = sampler.sample(S=ddim_steps,
                                                      conditioning=c,
                                                      conditioning_single=c_sin,
                                                      batch_size=batch_size,
                                                      shape=shape,
                                                      verbose=False,
                                                      unconditional_guidance_scale=scale,
                                                      unconditional_conditioning=uc,
                                                      unconditional_conditioning_single=uc_sin,
                                                      eta=ddim_eta,
                                                      x_T=start_code)

        # Decode and map the decoder output to [0, 1] ((x + 1) / 2, clamped).
        x_samples_ddim = model.decode_first_stage(samples_ddim)
        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

        all_samples.append(x_samples_ddim)

    # Flatten (sample, batch) into one list of images and tile them, n_rows per row.
    grid = torch.stack(all_samples, 0)
    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
    grid = make_grid(grid, nrow=n_rows)

    # to image: channels-last, scaled to 0-255
    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
    os.makedirs('./output', exist_ok=True)
    # Derive the file name from the prompt. Path separators are replaced so a prompt
    # containing "/" cannot point into a non-existent sub-directory, and an empty
    # prompt falls back to a visible name instead of the hidden file ".jpg".
    file_stem = prompt.replace(" ", "-").replace("/", "_") or "output"
    out_path = os.path.join('./output', f'{file_stem}.jpg')
    Image.fromarray(grid.astype(np.uint8)).save(out_path)
    display(Image.open(out_path))
