In [None]:
import os
import time
from pathlib import Path
from random import randint

from diffusers import StableDiffusionPipeline
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
import torch
from torch import autocast

from countries import countries

In [None]:
# Device used for the pipeline, the seeded RNG and autocast; the half-precision
# (float16) weights loaded in this notebook assume a CUDA GPU.
DEVICE = "cuda"

# Either download the model repo from
# https://huggingface.co/lambdalabs/sd-pokemon-diffusers and point at the local copy,
# or use the hub id "lambdalabs/sd-pokemon-diffusers" to download the model
# automatically from the Hugging Face Hub.
# Set the SD_POKEMON_MODEL_PATH environment variable to override the default below
# without editing the notebook (the fallback is the original author's local path).
MODEL_PATH = os.environ.get(
    "SD_POKEMON_MODEL_PATH", "/home/sid/Desktop/sd-pokemon-diffusers"
)

In [None]:
# Load the fine-tuned Pokémon Stable Diffusion weights in half precision and move
# the pipeline to DEVICE.
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_PATH,
    # NOTE(review): newer diffusers releases deprecate revision="fp16" in favour of
    # variant="fp16"; kept as-is for compatibility with the installed version.
    revision="fp16",
    torch_dtype=torch.float16,
)
pipe = pipe.to(DEVICE)
# Compute attention in slices: somewhat slower, but lowers peak GPU memory.
pipe.enable_attention_slicing()


def _disabled_safety_checker(images, **kwargs):
    """Stand-in for the NSFW safety checker that passes every image through unchanged.

    Returns the ``(images, has_nsfw_concept)`` pair the pipeline expects from a real
    checker, with one ``False`` flag per image. A bare ``False`` breaks diffusers
    versions that iterate over the flags.
    """
    return images, [False] * len(images)


# Disabled to avoid false positives (NSFW content is unlikely for these prompts).
pipe.safety_checker = _disabled_safety_checker

In [None]:
def generate_image(prompt: str, seed=None):
    """Generate a single image for ``prompt`` with the global Stable Diffusion ``pipe``.

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: Optional integer seed for a ``torch.Generator`` on ``DEVICE``. Any
            integer, including 0, makes the result reproducible; ``None`` lets the
            pipeline draw its own random noise.

    Returns:
        The first entry of the pipeline output's ``images`` list.
    """
    # Compare against None explicitly: a truthiness check would silently ignore seed=0.
    generator = None if seed is None else torch.Generator(DEVICE).manual_seed(seed)
    with autocast(DEVICE):
        image = pipe(prompt, generator=generator).images[0]
    return image

In [None]:
# Height in pixels of the white banner added below each image for the caption.
CAPTION_HEIGHT = 60
# Caption font, loaded once (and before any directory is created) rather than per image.
CAPTION_FONT = ImageFont.truetype("OpenSans-Regular.ttf", 40)

# One timestamped output directory per run; parents=True also creates images/ if missing.
savedir = Path("images") / time.strftime("%Y%m%d-%H%M%S")
savedir.mkdir(parents=True, exist_ok=True)

for country_name in tqdm(countries):
    prompt = f"The country of {country_name} as a pokemon"
    # Random per-image seed, recorded in the filename so a given example can be regenerated.
    seed = randint(0, 10_000)
    image = generate_image(prompt, seed)

    # Annotate: paste the image onto a taller white canvas and write the name in the banner.
    base_width, base_height = image.size
    annotated_image = Image.new(
        "RGB", (base_width, base_height + CAPTION_HEIGHT), (255, 255, 255)
    )
    annotated_image.paste(image, (0, 0))
    draw = ImageDraw.Draw(annotated_image)
    draw.text((0, base_height), str(country_name), fill=(0, 0, 0), font=CAPTION_FONT)

    # Save it; "/" is a path separator, so replace it in case a country name contains one.
    filename = f"{prompt}_seed{seed}.png".replace("/", "-")
    annotated_image.save(savedir / filename)