# In [None]:
import argparse, os
import cv2
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
from torch.cuda.amp import autocast
from contextlib import nullcontext
from imwatermark import WatermarkEncoder

import sys
sys.path.insert(0, '/share/project/yfl/codebase/git/AltTools/Altdiffusion/src')
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler

# This notebook only runs inference: turn autograd off for everything below so
# no computation graph is recorded.
torch.set_grad_enabled(mode=False)

def load_model_from_config(config, ckpt, use_ema, verbose=False):
    """Build the model described by ``config.model`` and load weights from ``ckpt``.

    Args:
        config: config object whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: path of a checkpoint file read with ``torch.load``.
        use_ema: if True, load the ``"state_dict_ema"`` entry of the
            checkpoint; otherwise load ``"state_dict"``.
        verbose: if True, print the missing / unexpected keys reported by
            the non-strict ``load_state_dict``.

    Returns:
        The model in eval mode, moved to CUDA when available and to the CPU
        otherwise.

    Raises:
        KeyError: if the checkpoint has no entry for the requested state dict.
    """
    print(f"Loading model from {ckpt}")
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    key = "state_dict_ema" if use_ema else "state_dict"
    if key not in pl_sd:
        raise KeyError(
            f"checkpoint {ckpt!r} has no {key!r} entry "
            f"(available keys: {list(pl_sd.keys())})"
        )
    sd = pl_sd[key]

    # The model is constructed here from the config; weights are loaded next.
    model = instantiate_from_config(config.model)

    # strict=False: key mismatches are tolerated and only reported.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and len(missing) > 0:
        print("missing keys:")
        print(missing)
    if verbose and len(unexpected) > 0:
        print("unexpected keys:")
        print(unexpected)

    # Previously this was an unconditional model.cuda(), which fails on
    # CPU-only machines; fall back to the CPU like the caller does.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    return model

class OPT:
    """Plain attribute container for sampling options.

    It starts with no fields; callers attach them after construction
    (``opt.steps = 50`` and so on), much like an ``argparse.Namespace``.
    """

    def __init__(self) -> None:
        # Nothing to initialise: every option is assigned by the caller.
        return None

# Sampling options shared with the cells below (which may still tweak single
# fields such as opt.H / opt.W before calling generate()).
opt = OPT()
vars(opt).update(
    prompt="一个中国小男孩",
    steps=50,              # sampler steps
    ddim_eta=0.0,          # eta passed to the sampler
    n_iter=1,
    H=512,                 # output height in pixels
    W=512,                 # output width in pixels
    C=4,                   # latent channels
    f=8,                   # downsampling factor: latents are (H // f, W // f)
    n_samples=4,           # images per batch
    n_rows=0,              # images per grid row; 0 means n_samples
    scale=9.0,             # unconditional guidance scale
    plms=False,
    dpm=False,
    fixed_code=False,
    precision='autocast',
    use_ema=False,
)

opt.config = "/share/project/yfl/codebase/git/AltTools/Altdiffusion/src/configs/v2-inference-alt.yaml"
# Alternative checkpoints; uncomment one (and comment out the active one) to switch.
# opt.ckpt = "/share/project/yfl/codebase/stable_diffusion_2.0/stablediffusion/logs/train_2.1_small_lr/checkpoints/epoch=000000.ckpt"
# opt.ckpt = "/share/project/yfl/codebase/stable_diffusion_2.0/ckpt/v2-1_512-ema-pruned.ckpt"
# opt.ckpt = "/share/project/yfl/database/stable_diffusion_2.0/512-base-ema.ckpt"
# opt.ckpt = "/share/project/yfl/codebase/stable_diffusion_2.0/stablediffusion/logs/train_2.1_train_use_penultimate/checkpoints/epoch=000000.ckpt"
# opt.ckpt = "/share/project/yfl/database/stable_diffusion_2.0/512-base-ema.ckpt"
# opt.ckpt = "/share/project/yfl/codebase/stable_diffusion_2.0/stablediffusion/logs/train_2.1_train_use_penultimate_large_lr/checkpoints/epoch=000002.ckpt"
# opt.ckpt = "/share/project/yfl/codebase/stable_diffusion_2.0/stablediffusion/logs/train_2.1_train_use_penultimate_large_lr_1e06/checkpoints/epoch=000002.ckpt"
# opt.ckpt = "/share/project/yfl/codebase/stable_diffusion_2.0/stablediffusion/logs/train_2.1_train_use_penultimate_large_lr_1e05/checkpoints/epoch=000001.ckpt"
# opt.ckpt = "/share/project/yfl/codebase/stable_diffusion_2.0/stablediffusion/logs/test_lr/checkpoints/step=00050000.ckpt"
# opt.ckpt = "/share/project/yfl/codebase/stable_diffusion_2.0/stablediffusion/logs/m18/checkpoints/epoch=000000.ckpt"
# opt.ckpt = "/share/project/yfl/codebase/git/AltTools/Altdiffusion/ckpt/laion5plus_real/checkpoints/step=000450000.ckpt"
# opt.ckpt = "/share/project/liuguang/alt_ckpts/step=000180000.ckpt"
# opt.ckpt = "/share/project/yfl/codebase/git/AltTools/Altdiffusion/ckpt/laion5plus_256_kv/checkpoints/step=000030000.ckpt"
# opt.ckpt = '/share/project/yfl/codebase/git/AltTools/Altdiffusion/ckpt/ema_xformer_laion6plus_512_all_new_v2/checkpoints/step=000007500.ckpt'
# opt.ckpt = '/share/project/yfl/codebase/git/AltTools/Altdiffusion/ckpt/laion5plus_256_kv/checkpoints/step=000030000.ckpt'
# opt.ckpt = "/share/project/yfl/codebase/stable_diffusion_2.0/stablediffusion/logs/test_lr/checkpoints/step=00037441.ckpt"
# opt.ckpt = '/share/project/yfl/codebase/git/AltTools/Altdiffusion/ckpt/xformer_laion5plus_512_kv_cfg/checkpoints/step=000015000.ckpt'
opt.ckpt = '/share/project/yfl/codebase/git/AltTools/Altdiffusion/ckpt/ckpt_20230315/step=000330000.ckpt'
# opt.ckpt = '/share/project/yfl/database/ckpt/aethetics_all_ema_cfg/step=000025000.ckpt'

config = OmegaConf.load(opt.config)
model = load_model_from_config(config, opt.ckpt, opt.use_ema)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Sampler choice: PLMS takes precedence over DPM-Solver; DDIM is the default.
sampler_cls = PLMSSampler if opt.plms else (DPMSolverSampler if opt.dpm else DDIMSampler)
sampler = sampler_cls(model)


def generate(prompt, negative_prompt, seed, opt):
    """Sample images for ``prompt`` and show them as one grid image.

    Relies on the module-level ``model``, ``sampler`` and ``device`` built
    earlier in the notebook.

    Args:
        prompt: text prompt, repeated ``opt.n_samples`` times to form a batch.
        negative_prompt: text encoded as the unconditional conditioning for
            guidance; only encoded when ``opt.scale != 1.0``.
        seed: seed handed to ``seed_everything`` before any sampling.
        opt: option bag. Reads ``steps``, ``ddim_eta``, ``n_samples``,
            ``n_rows``, ``H``, ``W``, ``C``, ``f``, ``scale``, ``fixed_code``,
            ``precision`` and, if present, ``n_iter`` (number of batches,
            default 1).

    Returns:
        None. The grid is displayed with ``PIL.Image.Image.show``.
    """
    seed_everything(seed)
    batch_size = opt.n_samples
    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
    # opt.n_iter used to be ignored (always one batch). Honour it now, but
    # never run fewer than one batch so the grid below is never empty.
    n_iter = max(1, int(getattr(opt, "n_iter", 1)))
    # Per-sample latent shape handed to the sampler: (C, H // f, W // f).
    shape = [opt.C, opt.H // opt.f, opt.W // opt.f]

    start_code = None
    if opt.fixed_code:
        start_code = torch.randn([batch_size, *shape], device=device)

    prompts = batch_size * [prompt]
    precision_scope = autocast if opt.precision == "autocast" else nullcontext
    with torch.no_grad(), precision_scope(True), model.ema_scope():
        # The negative prompt is identical for every batch, so encode it once.
        uc = None
        if opt.scale != 1.0:
            uc = model.get_learned_conditioning(batch_size * [negative_prompt])

        all_samples = []
        for _batch_idx in trange(n_iter, desc="Sampling"):
            c = model.get_learned_conditioning(prompts)
            samples, _ = sampler.sample(S=opt.steps,
                                        conditioning=c,
                                        batch_size=batch_size,
                                        shape=shape,
                                        verbose=False,
                                        unconditional_guidance_scale=opt.scale,
                                        unconditional_conditioning=uc,
                                        eta=opt.ddim_eta,
                                        x_T=start_code)

            x_samples = model.decode_first_stage(samples)
            # Rescale into [0, 1] for display (assumes decoder output is
            # roughly in [-1, 1]; anything outside is clamped).
            x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
            all_samples.append(x_samples)

        # Flatten (n_iter, batch) into one list of images and tile them.
        grid = torch.stack(all_samples, 0)
        grid = rearrange(grid, 'n b c h w -> (n b) c h w')
        grid = make_grid(grid, nrow=n_rows)

        # CHW floats in [0, 1] -> HWC uint8 array for PIL.
        grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
        Image.fromarray(grid.astype(np.uint8)).show()

# In [None]:
# Alternative prompts (English, Chinese, Japanese) kept for quick swapping:
# prompt = "A Pikachu"
# prompt = "一个穿着 EVA 插头套装的漂亮女孩的超写实绘画，超详细，动漫，作者 greg rutkowski，在 artstation 上流行"
# prompt = "ジャケット、ビクトリア朝、コンセプト アート、詳細な顔、ファンタジー、顔のクローズ アップ、非常に詳細な、映画のような照明、グレッグ rutkowski によるデジタル アートの絵画で頑丈な 19 世紀の男の肖像"
# prompt = "Hyper realistic painting of a beautiful girl in an EVA plugsuit, hyper detailed, anime, by greg rutkowski, trending on artstation"
# prompt = "Pikachu commiting tax fraud, paperwork, exhausted, cute, really cute, cozy,by steve hanks, by lisa yuskavage, by serov valentin, by tarkovsky, 8 k render, detailed, cute cartoon style"
prompt = "a lively magical town inspired by victorian england and amsterdam, sunny weather, highly detailed, intricate, digital painting, trending on artstation, concept art, matte painting, art by greg rutkwowski, craig mullins, octane render, 8 k, unreal engine"

# Alternative negative prompts:
# negative_prompt = ''
# negative_prompt = "low quality"
negative_prompt = "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry"

seed = 455561

# Portrait output (taller than wide). This mutates the shared opt, so later
# cells inherit these dimensions.
opt.W, opt.H = 512, 768
generate(prompt, negative_prompt, seed, opt)