<a href="https://colab.research.google.com/github/rinnakk/japanese-stable-diffusion/blob/master/scripts/img2img.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Image2Image pipeline for Japanese Stable Diffusion

Japanese Stable Diffusion is a Japanese-specific latent text-to-image diffusion model.

This Colab notebook shows how to use Japanese Stable Diffusion for image-to-image generation with 🤗 Diffusers.

## License

[The CreativeML OpenRAIL M license](LICENSE) is an [Open RAIL M license](https://www.licenses.ai/blog/2022/8/18/naming-convention-of-responsible-ai-licenses), adapted from the work that [BigScience](https://bigscience.huggingface.co/) and [the RAIL Initiative](https://www.licenses.ai/) are jointly carrying out in the area of responsible AI licensing. See also [the article about the BLOOM Open RAIL license](https://bigscience.huggingface.co/blog/the-bigscience-rail-license) on which our license is based.

## 1. Set Up

In [None]:
#@title 1.1 Check GPU Status
import subprocess

# Best-effort GPU report: list the visible NVIDIA devices, or print the reason
# the query failed (for example, `nvidia-smi` missing on a CPU-only runtime
# makes subprocess.run raise FileNotFoundError).
try:
    smi_result = subprocess.run(['nvidia-smi', '-L'], stdout=subprocess.PIPE)
    nvidiasmi_output = smi_result.stdout.decode('utf-8')
    print(nvidiasmi_output)
except Exception as err:
    print(err)

## 2. Install packages and define necessary functions

In [None]:
try:
    from japanese_stable_diffusion import JapaneseStableDiffusionImg2ImgPipeline
except ImportError:
    # Package not installed yet: install it from GitHub, then retry the import.
    # Only ImportError is handled; any other error raised while importing an
    # installed-but-broken package should surface instead of triggering a
    # pointless reinstall.
    import subprocess
    import sys
    # `sys.executable -m pip` installs into the interpreter running this
    # notebook rather than whichever `pip` happens to be first on PATH.
    # stderr is merged into stdout so pip's error output is shown when
    # printed below.
    res = subprocess.run(
        [sys.executable, '-m', 'pip', 'install', 'git+https://github.com/rinnakk/japanese-stable-diffusion'],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    ).stdout.decode('utf-8')
    print(res)
    from japanese_stable_diffusion import JapaneseStableDiffusionImg2ImgPipeline
import io, requests
import torch
from torch import autocast
from diffusers import DDIMScheduler
from PIL import Image
from IPython import display


def make_grid_from_pils(pil_images):
    """Tile PIL images left-to-right into one horizontal RGB strip.

    The tile size is taken from the first image.  Image ``i`` is pasted with
    its top-left corner at ``(i * tile_width, 0)``, so the inputs are expected
    to share that size.  Indexing ``pil_images[0]`` means an empty list raises
    ``IndexError``.
    """
    tile_width, tile_height = pil_images[0].size
    canvas_size = (tile_width * len(pil_images), tile_height)
    canvas = Image.new("RGB", canvas_size)
    anchors = [(slot * tile_width, 0) for slot in range(len(pil_images))]
    for tile, anchor in zip(pil_images, anchors):
        canvas.paste(tile, anchor)
    return canvas


def fetch(url_or_path, timeout=60):
    """Open a remote URL or a local path as a binary file-like object.

    Args:
        url_or_path: An ``http://``/``https://`` URL, or anything ``open()``
            accepts as a path (``str``, ``pathlib.Path``, ...).
        timeout: Seconds ``requests`` waits for the server before giving up.
            Without it a stalled server would hang the notebook forever.
            Pass ``None`` to wait indefinitely.

    Returns:
        For URLs, an in-memory ``io.BytesIO`` holding the response body,
        positioned at the start.  For local paths, a file object opened in
        ``'rb'`` mode that the caller must close (e.g. ``with fetch(p) as f:``).

    Raises:
        requests.HTTPError: if the server responds with an error status.
        OSError: if the local file cannot be opened.
    """
    if str(url_or_path).startswith(('http://', 'https://')):
        response = requests.get(url_or_path, timeout=timeout)
        response.raise_for_status()
        # BytesIO(initial_bytes) starts at position 0, so no seek() is needed.
        return io.BytesIO(response.content)
    return open(url_or_path, 'rb')

## 3. Load model

You need to accept the model license before downloading or using the weights. To do so, visit the [model card](https://huggingface.co/rinna/japanese-stable-diffusion), read the license, and tick the checkbox if you agree.

You have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).

In [None]:
from huggingface_hub import notebook_login

# Log in to the Hugging Face Hub so that the model download below, which
# passes `use_auth_token=True` to `from_pretrained`, can authenticate.
notebook_login()

In [None]:
model_id = "rinna/japanese-stable-diffusion"
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU: PyTorch's CPU backend lacks float16 kernels for
# several ops a diffusion pipeline uses (e.g. conv2d in many releases), so a
# float16 model on CPU fails at inference time.  Fall back to float32 there.
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Downloading the weights requires an accepted model license and a Hub login
# (see the notebook_login() cell).
pipe = JapaneseStableDiffusionImg2ImgPipeline.from_pretrained(
    pretrained_model_name_or_path=model_id,
    torch_dtype=torch_dtype,
    use_auth_token=True
).to(device)

## 4. Run!

In [None]:
#@title Do the Run!
import gradio as gr


def infer(
        prompt,
        init_image=None,
        strength=0.75,
        n_samples=4,
        guidance_scale=7.5,
        steps=50,
        width=512,
        height=512,
        seed="random",
):
    """Run the image-to-image pipeline for one prompt and return its samples.

    This is wired to the Gradio UI below.  Numeric arguments may arrive as
    floats (from ``gr.Number``/``gr.Slider``), so they are cast with ``int()``
    wherever the pipeline needs an integer.

    Args:
        prompt: Text prompt, repeated ``n_samples`` times in one batch.
        init_image: Starting image, a PIL image as produced by
            ``gr.Image(type="pil")``.  Required despite the ``None`` default.
        strength: Forwarded unchanged as the pipeline's ``strength``.
        n_samples: Number of images to generate; must be at least 1.
        guidance_scale: Forwarded as ``guidance_scale`` (the UI's "cfg_scale").
        steps: Forwarded as ``num_inference_steps``.
        width, height: Size the initial image is resized to before the call.
        seed: ``"random"`` (any case), an empty string, or ``None`` for an
            unseeded run.  Otherwise an integer or integer string used to seed
            a ``torch.Generator``, so that identical inputs repeat the output.

    Returns:
        The ``"sample"`` entry of the pipeline output.  This is presumably a
        list of PIL images, since it feeds the Gradio gallery; TODO confirm
        against the installed pipeline version.

    Raises:
        ValueError: if ``init_image`` is missing, ``n_samples`` < 1, or
            ``seed`` is not an integer string.
    """
    # Validate before touching the model so bad UI input fails fast with a
    # readable message instead of an AttributeError deep inside the call.
    if init_image is None:
        raise ValueError("An initial image is required for image-to-image generation.")
    n_samples = int(n_samples)
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1.")

    # Normalize free-text seed input from the Textbox: surrounding whitespace,
    # capitalization of "random", and an empty box all mean "unseeded".
    if isinstance(seed, str):
        seed = seed.strip()
        if seed.lower() in ("", "random"):
            seed = None
    if seed is None:
        generator = None
    else:
        generator = torch.Generator(device=device).manual_seed(int(seed))

    init_image = init_image.convert("RGB").resize((int(width), int(height)))
    with autocast(device):
        images = pipe(
            prompt=[prompt] * n_samples,
            init_image=init_image,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=int(steps),
            generator=generator
        )["sample"]
    return images


# NOTE(review): the `.style(...)` calls and `gr.Box` below are Gradio 3.x APIs
# that were removed in Gradio 4; keep gradio pinned accordingly.
block = gr.Blocks(css=".container { max-width: 800px; margin: auto; }")

with block as demo:
    gr.Markdown("<h1><center>Japanese Stable Diffusion</center></h1>")
    gr.Markdown(
        "Japanese Stable Diffusion is a Japanese-specific latent text-to-image diffusion model capable of generating photo-realistic images given any text input."
    )
    with gr.Group():
        # Prompt box and "Run" button share one row.
        with gr.Box():
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Enter your prompt in Japanese", show_label=False, max_lines=1,
                    placeholder="猫の肖像画 油絵"
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Run").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )

        # Generation parameters forwarded to infer().
        strength_slider = gr.Slider(
            label="Strength",
            maximum=1,
            value=0.75
        )
        image = gr.Image(
            label="Initial Image",
            type="pil",
            value="https://cdn.pixabay.com/photo/2015/11/16/14/43/cat-1045782_960_720.jpg"
        )
        n_samples = gr.Number(value=4, label="n_samples")
        scale = gr.Number(value=7.5, label="cfg_scale")
        steps = gr.Number(value=50, label="steps")
        width = gr.Slider(minimum=64, maximum=2048, value=512, label="width", step=64)
        height = gr.Slider(minimum=64, maximum=2048, value=512, label="height", step=64)
        seed = gr.Textbox(
            value='random',
            placeholder="If you fix seed, you get same outputs all the time. You can set as integer like 42.",
            label="seed",
        )

        gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto")

        # Pressing Enter in the prompt box and clicking "Run" trigger the same
        # call.  The component order must match infer()'s positional parameters.
        infer_inputs = [text, image, strength_slider, n_samples, scale, steps, width, height, seed]
        text.submit(infer, inputs=infer_inputs, outputs=gallery)
        btn.click(infer, inputs=infer_inputs, outputs=gallery)

    # The footer must be created while the `with block` context is still open.
    # Previously it was built after the context exited, so it was never
    # attached to (or shown in) the demo.
    gr.Markdown(
        """___
   <p style='text-align: center'>
   Created by https://huggingface.co/rinna
   <br/>
   </p>"""
    )

demo.launch(debug=True)