In [1]:
import os
import gc
import torch
import numpy as np

from PIL import Image
from torch import autocast
from einops import rearrange
from importlib import import_module
from omegaconf import OmegaConf
from tqdm.auto import tqdm, trange
from torchvision.utils import make_grid
tqdm.pandas()

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler

In [2]:
# Turn on cuDNN benchmark mode for the rest of the session.
torch.backends.cudnn.benchmark = True

# Device handles reused by the cells below.
gpu, cpu = torch.device("cuda"), torch.device("cpu")

In [3]:
# Alternative checkpoints (swap the path below to try one of them):
#   './model/ckpt/v2-1_512-nonema-pruned.ckpt'
#   './model/ckpt/wd-1-4-anime_e2.ckpt'
#
# NOTE: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints obtained from a source you trust.
#
# map_location='cpu' keeps the raw checkpoint tensors in host memory instead of
# restoring them to the device they were saved from, so the GPU only ever holds
# the copy that is moved there explicitly with the model.
pl_sd = torch.load('./model/ckpt/trinart2_step115000.ckpt', map_location='cpu')

# Weights live under the 'state_dict' key of the loaded checkpoint dict.
sd = pl_sd['state_dict']

In [6]:
# Alternative configs (swap the path below to try one of them):
#   './model/yaml/v2-inference.yaml'
#   './model/yaml/wd-1-4-anime_e1.yaml'
#   './model/yaml/v2-inference-v.yaml'
#
# NOTE(review): this is an inpainting config, while the rest of the notebook
# only performs plain text-to-image sampling -- confirm that this config is the
# one that matches the checkpoint loaded above.
config = OmegaConf.load('./model/yaml/v2-inference-inpainting.yaml')

# Build the class named by `config.model.target` with `config.model.params`
# using the ldm helper that is already imported, instead of resolving the
# dotted path by hand (which also left `module` / `_cls` in the namespace).
model = instantiate_from_config(config.model)

LatentInpaintDiffusion: Running in eps-prediction mode
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query 

In [7]:
# Toggle for reporting key mismatches between the checkpoint and the model.
verbose = True

# strict=False: mismatching keys are returned here instead of raising.
m, u = model.load_state_dict(sd, strict=False)

if verbose:
    # Same report as before: missing keys first, then unexpected keys;
    # an empty list prints nothing.
    for label, keys in (("missing keys:", m), ("unexpected keys:", u)):
        if len(keys) > 0:
            print(label)
            print(keys)

In [None]:
# Switch to inference mode, then cast to fp16 and move to the GPU in a single
# step. The previous separate `model.cuda()` call was redundant and first put a
# full-precision copy of the weights on the GPU before the cast.
model.eval()
model = model.to(gpu, dtype=torch.float16)

In [None]:
class Option:
    """Sampling settings read by `txt2img` through the module-level `opt`."""

    # --- sampler settings ---
    steps: int = 30          # passed to the sampler as S
    ddim_eta: float = 0.0    # passed to the sampler as eta
    scale: float = 7.5       # unconditional guidance scale; 1.0 skips the empty-prompt conditioning
    n_samples: int = 1       # batch size (the prompt is repeated this many times)

    # --- latent geometry: sample shape is [C, H // f, W // f] ---
    C: int = 4
    H: int = 512
    W: int = 512
    f: int = 8


# Shared settings instance used by the sampling function below.
opt = Option()
def txt2img(prompt, n_iter=1, seed=None):
    """Generate images for a text prompt with the globally loaded `model`.

    Sampling settings (steps, batch size, guidance scale, eta and latent shape)
    are read from the module-level `opt`.

    Args:
        prompt: Text prompt; it is repeated `opt.n_samples` times to form a batch.
        n_iter: Number of sampling passes. Each pass yields `opt.n_samples` images.
        seed: Optional integer passed to `torch.manual_seed` before sampling so
            a run can be repeated. `None` (default) leaves the RNG state untouched.

    Returns:
        A list of `PIL.Image` objects with `n_iter * opt.n_samples` entries
        (empty when `n_iter` is not positive).
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Alternative samplers: PLMSSampler(model), DPMSolverSampler(model)
    sampler = DDIMSampler(model)

    prompts = opt.n_samples * [prompt]
    latent_shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
    images = []

    with torch.no_grad(), \
            autocast(enabled=True, dtype=torch.float16, device_type='cuda'), \
            model.ema_scope():
        # The conditioning depends only on the prompt and `opt`, so it is
        # computed once here instead of once per sampling pass.
        uc = None
        if opt.scale != 1.0:
            # Empty-prompt conditioning used for classifier-free guidance.
            uc = model.get_learned_conditioning(opt.n_samples * [""])
        c = model.get_learned_conditioning(prompts)

        for _ in range(n_iter):
            samples, _ = sampler.sample(S=opt.steps,
                                        conditioning=c,
                                        batch_size=opt.n_samples,
                                        shape=latent_shape,
                                        verbose=False,
                                        unconditional_guidance_scale=opt.scale,
                                        unconditional_conditioning=uc,
                                        eta=opt.ddim_eta)

            # Release cached memory from sampling before decoding the latents.
            gc.collect()
            torch.cuda.empty_cache()

            x_samples = model.decode_first_stage(samples)
            # Map the decoder output from [-1, 1] to [0, 1].
            x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

            for x_sample in x_samples:
                # channels-first tensor -> channels-last array scaled to 0..255
                pixels = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                images.append(Image.fromarray(pixels.astype(np.uint8)))

    return images

In [None]:
# Prompt text, passed to the model exactly as written here.
prompt = ' A astronaut riding a horse on the moon'

# Run a single sampling pass; the result is a list of PIL images.
images = txt2img(prompt, n_iter=1)

In [None]:
# Show the first generated image as the cell's output.
images[0]