<a href="https://colab.research.google.com/github/yuvalsolaz/hanglider/blob/main/text2image_collab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import sys
import numpy as np
import torch as th
from PIL import Image
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
from IPython import display
from google.colab import files 

In [2]:
!pip install ipyplot
import ipyplot


        You might encounter issues while running in Google Colab environment.
        If images are not displaying properly please try setting `force_b64` param to `True`.
        


In [3]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
has_cuda = th.cuda.is_available()
device = th.device('cuda' if has_cuda else 'cpu')
print(f'using: {device} device\n')

using: cuda device



In [4]:
# Sampling parameters
batch_size = 1        # number of samples kept per prompt
guidance_scale = 3.0  # weight applied to (cond_eps - uncond_eps) in classifier-free guidance

# Timestep-respacing settings, keyed by model name.
diffusion_steps = dict(
    base='100',         # used as options['timestep_respacing'] for the base model
    upsample='fast27',  # fast respacing schedule meant for the upsampling model
)

In [5]:
!pip install git+https://github.com/openai/glide-text2im

Collecting git+https://github.com/openai/glide-text2im
  Cloning https://github.com/openai/glide-text2im to /tmp/pip-req-build-_36f4ahj
  Running command git clone -q https://github.com/openai/glide-text2im /tmp/pip-req-build-_36f4ahj


In [None]:
# Load / Create base model.
from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import create_model_and_diffusion, model_and_diffusion_defaults

print(f'use {diffusion_steps} diffusion steps for fast sampling')
options = model_and_diffusion_defaults()
options['use_fp16'] = has_cuda  # half precision only when running on a GPU
options['timestep_respacing'] = diffusion_steps['base']
model, diffusion = create_model_and_diffusion(**options)
model.eval()
if has_cuda:
    model.convert_to_fp16()
# Moving the model and loading the checkpoint must happen on every device.
# Previously both steps were nested under `if has_cuda:`, so a CPU runtime
# silently sampled from randomly initialised weights.
model.to(device)
model.load_state_dict(load_checkpoint('base', device))
print('total base parameters', sum(x.numel() for x in model.parameters()))

use {'base': '100', 'upsample': 'fast27'} diffusion steps for fast sampling


Text2ImUNet(
  (time_embed): Sequential(
    (0): Linear(in_features=192, out_features=768, bias=True)
    (1): SiLU()
    (2): Linear(in_features=768, out_features=768, bias=True)
  )
  (input_blocks): ModuleList(
    (0): TimestepEmbedSequential(
      (0): Conv2d(3, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1): TimestepEmbedSequential(
      (0): ResBlock(
        (in_layers): Sequential(
          (0): GroupNorm32(32, 192, eps=1e-05, affine=True)
          (1): Identity()
          (2): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (h_upd): Identity()
        (x_upd): Identity()
        (emb_layers): Sequential(
          (0): SiLU()
          (1): Linear(in_features=768, out_features=384, bias=True)
        )
        (out_layers): Sequential(
          (0): GroupNorm32(32, 192, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Dropout(p=0.1, inplace=False)
          (3): Conv2d(192, 192, kernel_size=(3, 3)

In [None]:
 # Load / Create upsampler model.
from glide_text2im.model_creation import model_and_diffusion_defaults_upsampler

options_up = model_and_diffusion_defaults_upsampler()
options_up['use_fp16'] = has_cuda  # half precision only when running on a GPU
# NOTE(review): respacing for the upsampler is left disabled, as in the
# original cell, so diffusion_steps['upsample'] is not applied here.
# options_up['timestep_respacing'] = diffusion_steps['upsample']

model_up, diffusion_up = create_model_and_diffusion(**options_up)
model_up.eval()
if has_cuda:
    model_up.convert_to_fp16()
# Moving the model and loading the checkpoint must happen on every device.
# Previously both steps were nested under `if has_cuda:`, so a CPU runtime
# ended up with a randomly initialised upsampler.
model_up.to(device)
model_up.load_state_dict(load_checkpoint('upsample', device))

print('total upsampler parameters', sum(x.numel() for x in model_up.parameters()))

In [None]:
def create_pil_image(batch: th.Tensor):
    """Turn a 4-D batch tensor into a single 256x256 PIL image.

    Values are mapped with ``(x + 1) * 127.5`` (so -1 -> 0 and 1 -> 255),
    rounded, clamped to [0, 255] and cast to uint8. The items of the batch are
    then laid out side by side along the width before the final resize.

    Args:
        batch: 4-D tensor; assumed to be (N, 3, H, W) -- TODO confirm with callers.

    Returns:
        The PIL image, resized to 256x256 regardless of how many items were tiled.
    """
    row_count = batch.shape[2]
    pixels = ((batch + 1) * 127.5).round().clamp(0, 255).to(th.uint8).cpu()
    # Axes (0, 1, 2, 3) -> (2, 0, 3, 1), then merge the middle two axes so the
    # batch items sit next to each other in one strip with 3 channels last.
    strip = pixels.permute(2, 0, 3, 1).reshape([row_count, -1, 3])
    return Image.fromarray(strip.numpy()).resize(size=(256, 256))

def create_pygame_image(batch_tuple: th.Tensor):
    """Turn the first item of a batch into a 256x256 PIL image.

    Only ``batch_tuple[0]`` is used. Values are mapped with ``(x + 1) * 127.5``,
    rounded, clamped to [0, 255] and cast to uint8 before conversion.

    Args:
        batch_tuple: indexable batch whose first item is a 3-D tensor; assumed
            to be (3, H, W) -- TODO confirm with callers.

    Returns:
        The PIL image, resized to 256x256.
    """
    batch = batch_tuple[0]
    scaled = ((batch + 1) * 127.5).round().clamp(0, 255).to(th.uint8).cpu()
    # After permute the layout is (H, W, C), so the row count is shape[1] (H).
    # The previous code used shape[2] (W), which only worked for square inputs
    # and scrambled the pixels otherwise.
    reshaped = scaled.permute(1, 2, 0).reshape([batch.shape[1], -1, 3])
    npimage = reshaped.numpy()
    im = Image.fromarray(npimage)
    return im.resize(size=(256,256))

def image_show_pil(batch: th.Tensor, caption=''):
    """Convert ``batch`` with ``create_pil_image`` and open it in the PIL viewer.

    Args:
        batch: tensor forwarded unchanged to ``create_pil_image``.
        caption: passed to ``Image.show`` as ``title``.
    """
    pil_image = create_pil_image(batch)
    pil_image.show(title=caption)

In [None]:
# PIL frames appended by model_fn (one per model call); read later to build the GIF.
image_array = []
def generate_image(prompt:str):
    """Sample images from the base model for ``prompt`` with classifier-free guidance.

    Reads the notebook globals ``model``, ``diffusion``, ``options``, ``device``,
    ``batch_size`` and ``guidance_scale``. As a side effect, ``image_array`` is
    emptied and then refilled with one frame per model call of this run.

    Args:
        prompt: caption to condition the generation on.

    Returns:
        A tuple ``(samples, eps_arr, rest_arr)``: the first ``batch_size``
        sampled images, plus the lists of per-call ``eps`` and ``rest`` slices
        of the raw model output.
    """
    # Start from an empty frame list so that a second run (e.g. a new prompt)
    # does not mix frames from the previous run into the next GIF.
    image_array.clear()

    eps_arr = []
    rest_arr = []
    # Create a classifier-free guidance sampling function
    def model_fn(x_t, ts, **kwargs):
        # Both halves of the batch share the same noisy input; they differ only
        # in the tokens supplied through kwargs (prompt vs. empty).
        half = x_t[: len(x_t) // 2]
        combined = th.cat([half, half], dim=0)
        model_out = model(combined, ts, **kwargs)
        # First three channels are eps; the remaining channels are passed through.
        eps, rest = model_out[:, :3], model_out[:, 3:]
        eps_arr.append(eps)
        rest_arr.append(rest)
        cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
        eps = th.cat([half_eps, half_eps], dim=0)
        # Record a frame for the GIF from the non-eps channels of this call.
        image_array.append(create_pygame_image(rest))
        return th.cat([eps, rest], dim=1)

    ##############################
    # Sample from the base model #
    ##############################

    # Create the text tokens to feed to the model.
    tokens = model.tokenizer.encode(prompt)
    tokens, mask = model.tokenizer.padded_tokens_and_mask(tokens, options['text_ctx'])

    # Create the classifier-free guidance tokens (empty)
    full_batch_size = batch_size * 2
    uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask([], options['text_ctx'])

    # Pack the tokens together into model kwargs.
    model_kwargs = dict(
        tokens=th.tensor(
            [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
        ),
        mask=th.tensor(
            [mask] * batch_size + [uncond_mask] * batch_size,
            dtype=th.bool,
            device=device,
        ),
    )

    # Sample from the base model.
    model.del_cache()
    try:
        samples = diffusion.p_sample_loop(
            model_fn,
            (full_batch_size, 3, options["image_size"], options["image_size"]),
            device=device,
            clip_denoised=True,
            progress=True,
            model_kwargs=model_kwargs,
            cond_fn=None,
        )[:batch_size]
    finally:
        # Drop the model cache even when sampling is interrupted or fails.
        model.del_cache()
    return samples, eps_arr, rest_arr

#### Enter an image caption for image generation:

In [None]:
def make_gif(frames,  gif_name):
    """Save a selection of ``frames`` as a looping animated GIF.

    Frame indices are the truncated running totals of ``log((steps + 1) / k)``
    for ``k = 1 .. steps - 1``. The increments shrink as ``k`` grows, so early
    frames are skipped over in large jumps while late frames are sampled
    densely (and may repeat).

    Args:
        frames: non-empty sequence of PIL images (anything with a PIL-style
            ``save`` method on the first element).
        gif_name: output path without the ``.gif`` extension.

    Returns:
        The path of the written file, ``f"{gif_name}.gif"``.

    Raises:
        ValueError: if ``frames`` is empty.
    """
    steps = len(frames)
    if steps == 0:
        raise ValueError("make_gif needs at least one frame")
    frame_one = frames[0]
    y = np.log(steps + 1) - np.log(np.arange(1, steps))
    # Running totals with a leading 0 reproduce [y[:i].sum() for i in range(steps)]
    # in a single pass instead of re-summing a growing prefix for every index.
    offsets = np.concatenate(([0.0], np.cumsum(y)))
    ijump = offsets.astype(int).tolist()
    gif_frames = [frames[i] for i in ijump]
    gif_file = f"{gif_name}.gif"
    # duration is per frame (milliseconds in Pillow's GIF writer); loop=0 repeats forever.
    frame_one.save(gif_file, format="GIF", append_images=gif_frames,
                   save_all=True, duration=100, loop=0)
    return gif_file

In [None]:
# Ask for a caption, generate an image for it and show the result inline.
prompt = input('enter image caption for image generation: ')
print (f'generating image for: {prompt}')
images = []  # PIL images produced in this cell
for i in range(1):  # a single generation pass
    # eps_arr / rest_arr hold the per-call model outputs returned by generate_image.
    samples,eps_arr, rest_arr  = generate_image(prompt=prompt)
    im = create_pil_image(samples)
    images.append(im)
    ipyplot.plot_images([im])

In [None]:

# Build the GIF from the frames collected during sampling; the prompt is the file stem.
gif_file_name = make_gif(frames=image_array, gif_name=prompt)

In [None]:
# Display GIF in Jupyter, CoLab, IPython
with open(gif_file_name,'rb') as f:
    # Render explicitly instead of relying on the notebook-wide
    # ast_node_interactivity = "all" setting to echo a bare expression.
    # NOTE(review): format stays 'png' even though the bytes are GIF data --
    # presumably a workaround for IPython versions that cannot embed 'gif';
    # confirm before changing it.
    display.display(display.Image(data=f.read(), format='png'))

In [None]:
# Send the generated GIF to the browser as a download (Colab only).
files.download(gif_file_name)