Citation: https://github.com/kobiso/DALLE-reproduction

In [2]:
# For Google Collab only
# import platform
# print(platform.dist())
from google.colab import drive
drive.mount('/content/drive')

!cp /content/drive/MyDrive/DL/DALLE-reproduction-main/requirements.txt .

# Install the required libraries to runtime.

# Torchtest package
!pip install -r requirements.txt
!pip install torch==1.8.1

Mounted at /content/drive
Collecting git+https://github.com/openai/CLIP.git (from -r requirements.txt (line 1))
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-jzdvymvy
  Running command git clone -q https://github.com/openai/CLIP.git /tmp/pip-req-build-jzdvymvy
Collecting taming-transformers
[?25l  Downloading https://files.pythonhosted.org/packages/8a/c2/ae7227e4b089c6a8210920db9d5ac59186b0a84eb1e6d96b9218916cdaf1/taming_transformers-0.0.1-py3-none-any.whl (45kB)
[K     |████████████████████████████████| 51kB 4.3MB/s 
[?25hCollecting dalle-pytorch==0.7.2
[?25l  Downloading https://files.pythonhosted.org/packages/8d/9a/561a042a32a82b7a83cbb688d0b4103a46eb8a3f83b985635dee71201bb6/dalle_pytorch-0.7.2-py3-none-any.whl (1.4MB)
[K     |████████████████████████████████| 1.4MB 11.4MB/s 
[?25hCollecting tokenizers
[?25l  Downloading https://files.pythonhosted.org/packages/ae/04/5b870f26a858552025a62f1649c20d29d2672c02ff3c3fb4c688ca46467a/tokenizers-0.10.2-cp37-cp37m-

In [3]:
# !pip install torch==1.8.1
# !pip install torchvision==0.9.0

import os
# Order CUDA devices by PCI bus ID and expose only device "7" to this process.
# Both variables are set before torch is imported further down in this cell.
# NOTE(review): index "7" assumes a multi-GPU host; on a single-GPU runtime
# (e.g. Colab) this presumably hides the only GPU -- confirm before relying on CUDA.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="7"

# !pip install 'dalle-pytorch==0.7.2'

import dalle_pytorch
import torch
import torch.nn.functional as F
from torchvision.utils import make_grid, save_image
from dalle_pytorch import VQGanVAE1024, DALLE
from tokenizers import Tokenizer

%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
from PIL import Image
import numpy as np
import clip

def show(img, figsize=(100, 40)):
    """Display a channel-first image tensor with matplotlib.

    Args:
        img: torch tensor laid out as (channels, height, width). It may live on
            any device and may require grad; it is detached and copied to the
            CPU before being converted to a numpy array.
        figsize: (width, height) of the figure in inches. Defaults to the
            previously hard-coded (100, 40).
    """
    # Tensor.numpy() raises for CUDA tensors and for tensors that require grad,
    # so normalise to a detached CPU tensor first.
    npimg = img.detach().cpu().numpy()
    plt.figure(figsize=figsize)
    # Reorder the axes from (C, H, W) to (H, W, C) before handing to imshow.
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')

### Prepare VQGAN and BPE

In [4]:
DATA = "cub"  # Dataset to use: "cub" or "coco"

# Select the BPE tokenizer file and the pretrained DALLE checkpoint for the
# chosen dataset.
if DATA == "cub":
    BPE_path = "/content/drive/MyDrive/DL/DALLE-reproduction-main/BPE/cub200_bpe_vsize_7800.json"
    dalle_path = "/content/drive/MyDrive/DL/DALLE-reproduction-main/pretrained/cub200_adam_frcc.pth"
elif DATA == "coco":
    # NOTE(review): these paths are relative to the working directory, unlike
    # the absolute Drive paths used for "cub" -- confirm the files exist there.
    BPE_path = "BPE/coco_bpe_vsize_15000.json"
    dalle_path = "pretrained/coco_adam.pth"
else:
    # Fail here with a clear message instead of a NameError on BPE_path below.
    raise ValueError(f'Unsupported DATA value {DATA!r}; expected "cub" or "coco".')

# Settings read later when constructing DALLE (emb_dim is passed as its `dim`).
vae_dict = {"args": {"image_size": 256, "emb_dim": 256}}
vae = VQGanVAE1024()

# Tokenizer loaded from the BPE file selected above.
vocab = Tokenizer.from_file(BPE_path)

100%|███████████████████████████████████████| 645/645 [00:00<00:00, 5477.56it/s]
100%|███████████████████████| 957954257/957954257 [00:45<00:00, 21263348.47it/s]


Working with z of shape (1, 256, 16, 16) = 65536 dimensions.


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))


Downloading vgg_lpips model from https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1 to taming/modules/autoencoder/lpips/vgg.pth


8.19kB [00:00, 351kB/s]                    


loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth
VQLPIPSWithDiscriminator running with hinge loss.


### Load pretrained DALLE model

In [5]:
# Prepare DALLE: load the checkpoint onto the CPU regardless of the device it
# was saved from.
dalle_dict = torch.load(dalle_path, map_location="cpu")

# Reformat attention types: the checkpoint stores them as one comma-separated
# string, and they are passed to DALLE as a tuple.
# A generator expression is used so that no loop variable (previously named
# `type`, shadowing the builtin) leaks into the notebook namespace.
VALID_ATTN_TYPES = ("full", "sparse", "axial_row", "axial_col", "conv_like")
attn_types = tuple(
    attn_type.strip()
    for attn_type in dalle_dict['args']['attn_types'].split(",")
)
unknown_attn_types = [name for name in attn_types if name not in VALID_ATTN_TYPES]
if unknown_attn_types:
    # Explicit error instead of `assert`, which is skipped under `python -O`.
    raise ValueError(
        f"Unknown attention type(s) {unknown_attn_types}; "
        f"expected a subset of {VALID_ATTN_TYPES}."
    )
print("[Log] Attention types: ", attn_types)

# Rebuild the model with the hyper-parameters stored in the checkpoint; the
# embedding dimension comes from the VAE settings defined earlier.
dalle = DALLE(
        dim=vae_dict['args']['emb_dim'],
        vae=vae,
        num_text_tokens=dalle_dict['args']['num_text_tokens'],
        text_seq_len=dalle_dict['args']['text_seq_len'],
        depth=dalle_dict['args']['depth'],
        heads=dalle_dict['args']['heads'],
        reversible=dalle_dict['args']['reversible'],
        attn_types=attn_types,
)

# Restore the pretrained weights; as the last expression of the cell, the
# result of load_state_dict is displayed.
dalle.load_state_dict(dalle_dict['dalle'])

[Log] Attention types:  ('full', 'axial_row', 'axial_col', 'conv_like')


<All keys matched successfully>

## Generate image from text

In [6]:
# Captions to condition the image generation on (one image per caption).
input_text = ["the medium sized bird has a dark grey color, a black downward curved beak, and long wings.",
"the bird is dark grey brown with a thick curved bill and a flat shaped tail.",
"bird has brown body feathers, white breast feathers and black beak",
"this bird has a dark brown overall body color, with a small white patch around the base of the bill.",
"the bird has very long and large brown wings, as well as a black body and a long black beak.",
"it is a type of albatross with black wings, tail, back and beak, and has a white ring at the base of its beak.",
"this bird has brown plumage and a white ring at the base of its long, curved brown beak.",
"the entire body is dark brown, as is the bill, with a white band encircling where the bill meets the head.",
"this bird is gray in color, with a large curved beak.",
"a large gray bird with a long wingspan and a long black beak."
]

# Directory the generated images are written to.
output_dir = "/content/drive/MyDrive/DL/DALLE-reproduction-main/"

text_seq_len = dalle_dict['args']['text_seq_len']
sot_token = vocab.encode("<|startoftext|>").ids[0]
eot_token = vocab.encode("<|endoftext|>").ids[0]

# Encode every caption as <sot> + BPE ids + <eot>, zero-padded on the right to
# exactly text_seq_len entries so that the rows can be stacked.
token_list = []
for txt in input_text:
    tokens = [sot_token] + vocab.encode(txt).ids + [eot_token]
    if len(tokens) > text_seq_len:
        # Previously an over-long caption grew the padded list beyond
        # text_seq_len and produced ragged rows. Truncate instead, keeping the
        # end-of-text marker as the final token.
        tokens = tokens[:text_seq_len - 1] + [eot_token]
    codes = tokens + [0] * (text_seq_len - len(tokens))
    token_list.append(torch.LongTensor(codes))
text = torch.stack(token_list)
# Id 0 is the padding value used above; the mask marks the non-padding positions.
mask = (text != 0)

images = dalle.generate_images(text, mask = mask, filter_thres = 0.9, temperature=1.0)
print(output_dir + str(0) + "_result.png")

# Save each generated image as <index>_result.png in output_dir.
for image_name, image in enumerate(images):
    save_image(image, output_dir + str(image_name) + "_result.png")

# `range` only takes effect when normalize=True, and the keyword was removed in
# newer torchvision releases, so it is omitted here.
grid = make_grid(images, nrow=4, normalize=False).cpu()
show(grid)

Output hidden; open in https://colab.research.google.com to view.

In [7]:
# import pickle

# pickle_list =[]
# pickle_file = open("/content/drive/MyDrive/DL/DALLE-reproduction-main/StackGAN-v2-master_data_birds_test_filenames.pickle", 'rb')
# while True:
#     try:
#         pickle_list.append(pickle.load(pickle_file))
#     except EOFError:
#         break
# print(pickle_list)
# pickle_file.close()

[['001.Black_footed_Albatross/Black_Footed_Albatross_0046_18', '001.Black_footed_Albatross/Black_Footed_Albatross_0009_34', '001.Black_footed_Albatross/Black_Footed_Albatross_0002_55', '001.Black_footed_Albatross/Black_Footed_Albatross_0074_59', '001.Black_footed_Albatross/Black_Footed_Albatross_0014_89', '001.Black_footed_Albatross/Black_Footed_Albatross_0085_92', '001.Black_footed_Albatross/Black_Footed_Albatross_0031_100', '001.Black_footed_Albatross/Black_Footed_Albatross_0051_796103', '001.Black_footed_Albatross/Black_Footed_Albatross_0010_796097', '001.Black_footed_Albatross/Black_Footed_Albatross_0025_796057', '001.Black_footed_Albatross/Black_Footed_Albatross_0023_796059', '001.Black_footed_Albatross/Black_Footed_Albatross_0086_796062', '001.Black_footed_Albatross/Black_Footed_Albatross_0049_796063', '001.Black_footed_Albatross/Black_Footed_Albatross_0006_796065', '001.Black_footed_Albatross/Black_Footed_Albatross_0040_796066', '001.Black_footed_Albatross/Black_Footed_Albatross

In [8]:
# input_text = ["this colorful bird has a yellow breast , with a black crown and a black cheek patch"] * 32

# token_list = []
# sot_token = vocab.encode("<|startoftext|>").ids[0]
# eot_token = vocab.encode("<|endoftext|>").ids[0]
# for txt in input_text:
#     codes = [0] * dalle_dict['args']['text_seq_len']
#     text_token = vocab.encode(txt).ids
#     tokens = [sot_token] + text_token + [eot_token]
#     codes[:len(tokens)] = tokens
#     # caption_token = torch.LongTensor(codes).cuda()
#     caption_token = torch.LongTensor(codes)
#     token_list.append(caption_token)
# text = torch.stack(token_list)
# # mask = (text != 0).cuda()
# mask = (text != 0)


# images = dalle.generate_images(text, mask = mask, filter_thres = 0.9, temperature=1.0)

# grid = make_grid(images, nrow=4, normalize=False, range=(-1, 1)).cpu()
# show(grid)

KeyboardInterrupt: ignored

Citation: https://github.com/kobiso/DALLE-reproduction