**This notebook loads weights for a decent VAE and a small DALL-E model that still needs much more training.**

 - Both models were trained on the competition data using the code from https://github.com/lucidrains/DALLE-pytorch

*lucid is all you need*

This is mostly just for fun, but if you figure out how to use it to improve the score, please let everyone know. Thanks!

In [None]:
# !conda install -y -c rdkit rdkit
!pip install rdkit-pypi dalle-pytorch youtokentome einops

In [None]:
import numpy as np, os
import pandas as pd
import re
from tqdm.auto import tqdm
tqdm.pandas()

import torch
from torch.utils.data import Dataset
import cv2

from albumentations import (
    Compose, OneOf, Normalize, Resize, HorizontalFlip, VerticalFlip, Rotate, RandomRotate90, CenterCrop
    )

from albumentations.pytorch import ToTensorV2

import rdkit.Chem as Chem

from pathlib import Path
from tqdm import tqdm

# torch

import torch

from einops import repeat

# vision imports

from PIL import Image
from torchvision.utils import make_grid, save_image

# dalle related classes and utils

from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE1024, DALLE
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, YttmTokenizer, ChineseTokenizer

from matplotlib import pyplot as plt

## functions to generate images via rdkit from this excellent notebook: https://www.kaggle.com/tuckerarrants/inchi-allowed-external-data

def sp_noise(image):
    """Return a copy of ``image`` with random salt-and-pepper noise applied.

    Adapted from https://gist.github.com/lucaswiman/1e877a164a69f78694f845eab45c381a

    One uniform draw is made per pixel (not per channel). Pixels whose draw is
    below 0.00015 become black ("pepper") and pixels whose draw is above 0.85
    become white ("salt"), so roughly 15% of pixels are whitened.

    Args:
        image: array of shape ``(H, W)`` or ``(H, W, C)``; expected to be
            ``uint8`` since noise values are 0 / 255. For ``C == 4`` (RGBA) the
            alpha of noisy pixels is set to 255 (opaque); any other channel
            count gets all-0 / all-255 per channel.

    Returns:
        A new array with the same shape and dtype; ``image`` is not modified.
    """
    output = image.copy()
    if image.ndim == 2:
        black, white = 0, 255
    else:
        channels = image.shape[2]
        if channels == 4:  # RGBA: keep noisy pixels fully opaque
            black = np.array([0, 0, 0, 255], dtype='uint8')
            white = np.array([255, 255, 255, 255], dtype='uint8')
        else:  # RGB, single-channel, or any other per-channel layout
            black = np.zeros(channels, dtype='uint8')
            white = np.full(channels, 255, dtype='uint8')
    probs = np.random.random(image.shape[:2])
    # Write into the copy, not the caller's array.
    output[probs < .00015] = black
    output[probs > .85] = white
    return output

def noisy_inchi(inchi, inchi_path='tmp1.png', add_noise=True, crop_and_pad=True):
    """Render an InChI string to an image with RDKit, optionally noised.

    Drawing options follow
    https://www.kaggle.com/stainsby/improved-synthetic-data-for-bms-competition-v3.
    The double-bond offset and atom-label padding are randomized on every call,
    so repeated calls give slightly different renderings.

    Args:
        inchi: InChI string to draw.
        inchi_path: scratch PNG path. The 512x512 drawing is written here and
            read back with OpenCV; the file is overwritten on every call.
        add_noise: if True, apply ``sp_noise`` salt-and-pepper noise.
        crop_and_pad: if True, read the PNG as grayscale, trim all-white
            border rows/columns, pad with white (30 px top/left, 20 px
            bottom/right) and convert to 3 channels. If False, the full
            canvas is returned as read by ``cv2.imread`` (BGR channel order).

    Returns:
        ``uint8`` image array with 3 channels.

    Raises:
        ValueError: if RDKit cannot parse ``inchi`` or the PNG cannot be read
            back from ``inchi_path``.
    """
    # Import the drawing submodules explicitly: ``import rdkit.Chem`` on its
    # own does not guarantee that ``Chem.Draw`` is available as an attribute.
    from rdkit.Chem import rdDepictor
    from rdkit.Chem.Draw import rdMolDraw2D

    mol = Chem.MolFromInchi(inchi)
    if mol is None:
        # MolFromInchi signals a parse failure with None; handing that to
        # DrawMolecule would fail with a much less helpful error.
        raise ValueError(f'RDKit could not parse InChI: {inchi!r}')

    drawer = rdMolDraw2D.MolDraw2DCairo(512, 512)
    rdDepictor.SetPreferCoordGen(True)
    opts = drawer.drawOptions()
    opts.maxFontSize = 16
    opts.multipleBondOffset = np.random.uniform(0.05, 0.2)
    opts.useBWAtomPalette()
    opts.bondLineWidth = 2
    opts.additionalAtomLabelPadding = np.random.uniform(0, .2)
    drawer.DrawMolecule(mol)
    drawer.FinishDrawing()
    drawer.WriteDrawingText(inchi_path)

    read_flag = cv2.IMREAD_GRAYSCALE if crop_and_pad else cv2.IMREAD_COLOR
    img = cv2.imread(inchi_path, read_flag)
    if img is None:
        raise ValueError(f'could not read rendered image back from {inchi_path!r}')

    if crop_and_pad:
        ink = img != 255
        # Trim only when something was drawn: an all-white canvas would be
        # cropped to an empty array, which copyMakeBorder rejects.
        if ink.any():
            img = img[ink.any(axis=1)][:, ink.any(axis=0)]
        img = cv2.copyMakeBorder(img, 30, 20, 30, 20, cv2.BORDER_CONSTANT, value=255)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

    if add_noise:
        img = sp_noise(img)
    return img

In [None]:
# Extra InChIs from the `extra-inchis` dataset; the sampling cell below reads
# its `InChI` and `InChI_text` columns.
# NOTE(review): pd.read_pickle unpickles arbitrary Python objects — only load
# this file from a source you trust.
df = pd.read_pickle(Path('../input/extra-inchis/allowed_inchi_processed.pkl'))

In [None]:
# Pick 8 molecules to condition on. `random_state` makes the draw reproducible
# under Restart & Run All; change or drop it to look at different molecules.
rows = df.sample(n=8, random_state=42)[['InChI', 'InChI_text']].to_numpy()

In [None]:
# Column 0 holds the InChI strings (rendered with RDKit), column 1 the text
# that gets tokenized for DALL-E.
inchis = rows[:, 0]
texts = rows[:, 1]

In [None]:
def display_multiple_img(images, rows=1, cols=1, figsize=(12, 8)):
    """Show a mapping of ``title -> image`` as a ``rows x cols`` grid of panels.

    Panels are filled row by row in the mapping's iteration order. Axes are
    hidden on every panel, including any unused ones.

    Args:
        images: dict mapping a panel title to anything ``Axes.imshow`` accepts.
        rows: number of grid rows.
        cols: number of grid columns; ``rows * cols`` must be >= ``len(images)``.
        figsize: figure size in inches, defaults to 12 x 8.

    Raises:
        ValueError: if the grid has fewer cells than there are images.
    """
    if len(images) > rows * cols:
        raise ValueError(f'{len(images)} images do not fit in a {rows}x{cols} grid')
    # squeeze=False always yields a 2-D array of Axes, so `.ravel()` also works
    # for a 1x1 grid, where plt.subplots would otherwise return a bare Axes.
    fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=figsize, squeeze=False)
    flat_axes = axes.ravel()
    for ax, (title, img) in zip(flat_axes, images.items()):
        ax.imshow(img)
        ax.set_title(title)
    for ax in flat_axes:
        ax.set_axis_off()
    # Lay out once the figure already has its final size.
    fig.tight_layout()
    plt.show()
    
def show_generated(output, raw_text, n_cols=3):
    """Plot DALL-E samples next to two RDKit renderings of the same InChI.

    Args:
        output: batch of generated images from ``dalle.generate_images``;
            each item is permuted from channel-first to channel-last, so this
            assumes (B, C, H, W) with C == 3 — TODO confirm.
        raw_text: InChI string, rendered twice with ``noisy_inchi`` (each call
            uses fresh random drawing options and noise).
        n_cols: panels per row. The row count is derived from the number of
            panels (B + 2), so any batch size fits.
    """
    # NOTE(review): `x * 2 + 1` is the display transform chosen by the author.
    # If the VAE was trained with mean=std=0.5 normalization, the exact inverse
    # would be `x * 0.5 + 0.5` — confirm against vae_params. Clipping to [0, 1]
    # matches what imshow does for float RGB data anyway, minus the warning.
    generated = [
        np.clip(h.detach().cpu().permute(1, 2, 0).numpy() * 2 + 1, 0, 1)
        for h in output.unbind(0)
    ]
    references = [noisy_inchi(raw_text), noisy_inchi(raw_text)]
    titles = [f'dalle_{i}' for i in range(len(generated))] + ['rdkit_1', 'rdkit_2']
    images = dict(zip(titles, generated + references))

    n_rows = -(-len(images) // n_cols)  # ceiling division
    display_multiple_img(images, n_rows, n_cols)

In [None]:
class Args:
    """Settings for loading the checkpoint and sampling images.

    The field names appear to mirror the command-line flags of dalle-pytorch's
    ``generate.py``; here they are plain class attributes.
    """

    # --- checkpoint & tokenizer -------------------------------------------
    dalle_path: str = '../input/moldal/dalle_gen2.pth'  # trained DALL-E checkpoint
    bpe_path: str = '../input/moldal/bms_bpe.model'     # BPE model for the tokenizer
    hug: bool = False       # with bpe_path set: True -> HugTokenizer, False -> YttmTokenizer
    chinese: bool = False   # ChineseTokenizer; only consulted when bpe_path is unset
    taming: bool = False

    # --- sampling ---------------------------------------------------------
    text: str = ''          # not read by the generation cell, which uses the sampled texts
    num_images: int = 4     # copies of each prompt to sample
    batch_size: int = 4     # prompt copies per generate_images call
    top_k: float = 0.97     # passed to generate_images as `filter_thres`

    # --- output -----------------------------------------------------------
    outputs_dir: str = './'

# One shared settings object, read by the loading and generation code below.
args = Args()

def exists(val):
    """Return True for any value other than None (falsy values included)."""
    if val is None:
        return False
    return True

import gc

# Text tokenizer: a BPE model takes precedence (HuggingFace or YouTokenToMe
# flavour). Otherwise use the Chinese tokenizer if requested, or keep the
# default `tokenizer` imported from dalle_pytorch.tokenizer.
if exists(args.bpe_path):
    klass = HugTokenizer if args.hug else YttmTokenizer
    tokenizer = klass(args.bpe_path)
elif args.chinese:
    tokenizer = ChineseTokenizer()

# load DALL-E

dalle_path = Path(args.dalle_path)

assert dalle_path.exists(), 'trained DALL-E must exist'

# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code — only load checkpoints from sources you trust.
load_obj = torch.load(str(dalle_path), map_location='cpu')

dalle_params, vae_params, weights = load_obj.pop('hparams'), load_obj.pop('vae_params'), load_obj.pop('weights')

dalle_params.pop('vae', None) # cleanup later

# Choose the VAE the same way dalle-pytorch's generate.py does, so the
# `taming` setting is honoured and a checkpoint saved without VAE params
# still loads (OpenAIDiscreteVAE fetches its own pretrained weights).
if args.taming:
    vae = VQGanVAE1024()
elif vae_params is not None:
    vae = DiscreteVAE(**vae_params)
else:
    vae = OpenAIDiscreteVAE()

# Fall back to CPU so the notebook still runs (slowly) without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dalle = DALLE(vae = vae, **dalle_params)
dalle.load_state_dict(weights)
dalle = dalle.to(device).eval()  # inference only

# The model now holds its own copy of the parameters: drop the checkpoint
# objects first, then collect, so the memory is actually released.
del load_obj, weights
gc.collect()

image_size = vae.image_size

In [None]:
# For each sampled molecule, draw `args.num_images` DALL-E samples and show
# them next to two RDKit renderings of the same InChI.
model_device = next(dalle.parameters()).device  # follow wherever the model was loaded

for inchi, inchi_text in zip(inchis, texts):
    tokens = tokenizer.tokenize(
        inchi_text,
        dalle.text_seq_len,
        truncate_text=True,
    ).to(model_device)

    # Repeat the single tokenized prompt once per requested image.
    tokens = repeat(tokens, '() n -> b n', b = args.num_images)
    for token_chunk in tqdm(tokens.split(args.batch_size), desc = f'generating images for - {inchi}'):
        output = dalle.generate_images(token_chunk, filter_thres = args.top_k)
        show_generated(output, inchi)

In [None]:
# for inchi, text in zip(inchis, texts):
#     text = tokenizer.tokenize(text,
#                               dalle.text_seq_len,
#                               truncate_text=True,
#                              ).cuda()

#     text = repeat(text, '() n -> b n', b = args.num_images)
#     for text_chunk in tqdm(text.split(args.batch_size), desc = f'generating images for - {inchi}'):
#         output = dalle.generate_images(text_chunk, filter_thres = args.top_k)
#         show_generated(output, inchi)
#         print(text_chunk)
#         break
#     break