In [1]:
import os
import argparse
import sys
import torch
from pytorch_pretrained_biggan import BigGAN as DMBigGAN
from pytorch_pretrained_biggan import  truncated_noise_sample, one_hot_from_names, display_in_terminal, save_as_images
from matplotlib import pyplot as plt

from pytorch_pretrained_biggan import BigGAN
import clip
import kornia
from PIL import Image
from torchvision.utils import save_image
import torchvision

import numpy as np
import pickle
from pymoo.optimize import minimize
from pymoo.algorithms.so_genetic_algorithm import GA #so_genetic_algorithm
from pymoo.factory import get_algorithm, get_decision_making, get_decomposition
from pymoo.visualization.scatter import Scatter

from scipy.stats import truncnorm

from pymoo.factory import get_sampling, get_crossover, get_mutation
from pymoo.operators.mixed_variable_operator import MixedVariableSampling, MixedVariableMutation, MixedVariableCrossover
from pymoo.model.sampling import Sampling

from pymoo.model.problem import Problem


In [2]:
class TruncatedNormalRandomSampling(Sampling):
    def __init__(self, var_type=np.float64):
        super().__init__()
        self.var_type = var_type

    def _do(self, problem, n_samples, **kwargs):
        return truncnorm.rvs(-2, 2, size=(n_samples, problem.n_var)).astype(np.float64)

class NormalRandomSampling(Sampling):
    def __init__(self, mu=0, std=1, var_type=np.float64):
        super().__init__()
        self.mu = mu
        self.std = std
        self.var_type = var_type

    def _do(self, problem, n_samples, **kwargs):
        return np.random.normal(self.mu, self.std, size=(n_samples, problem.n_var))

class BinaryRandomSampling(Sampling):
    def __init__(self, prob=0.5):
        super().__init__()
        self.prob = prob

    def _do(self, problem, n_samples, **kwargs):
        val = np.random.random((n_samples, problem.n_var))
        return (val < self.prob).astype(np.bool_)

def get_operators(config):
    if config["config"] == "DeepMindBigGAN256" or config["config"] == "DeepMindBigGAN512":
        mask = ["real"]*config["dim_z"] + ["bool"]*config["num_classes"]
        
        real_sampling = None
        if config["config"] == "DeepMindBigGAN256" or config["config"] == "DeepMindBigGAN512":
            real_sampling = TruncatedNormalRandomSampling()

        sampling = MixedVariableSampling(mask, {
            "real": real_sampling,
            "bool": BinaryRandomSampling(prob=5/1000)
        })

        crossover = MixedVariableCrossover(mask, {
            "real": get_crossover("real_sbx", prob=1.0, eta=3.0),
            "bool": get_crossover("bin_hux", prob=0.2)
        })

        mutation = MixedVariableMutation(mask, {
            "real": get_mutation("real_pm", prob=0.5, eta=3.0),
            "bool": get_mutation("bin_bitflip", prob=10/1000)
        })

        return dict(
            sampling=sampling,
            crossover=crossover,
            mutation=mutation
        )

In [3]:
def save_grid(images, path):
    grid = torchvision.utils.make_grid(images)
    torchvision.utils.save_image(grid, path)

def show_grid(images):
    grid = torchvision.utils.make_grid(images)
    plt.imshow(grid.permute(1, 2, 0).cpu().detach().numpy())
    plt.show()

def biggan_norm(images):
    images = (images + 1) / 2.0
    images = images.clip(0, 1)
    return images

def biggan_denorm(images):
    images = images*2 - 1
    return images


def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False

In [4]:
class DeepMindBigGAN(torch.nn.Module):
    def __init__(self, config):
        super(DeepMindBigGAN, self).__init__()
        self.config = config
        self.G = DMBigGAN.from_pretrained(config["weights"])
        self.D = None

    def has_discriminator(self):
        return False

    def generate(self, z, class_labels, minibatch = None):
        if minibatch is None:
            return self.G(z, class_labels, self.config["truncation"])
        else:
            assert z.shape[0] % minibatch == 0
            gen_images = []
            for i in range(0, z.shape[0] // minibatch):
                z_minibatch = z[i*minibatch:(i+1)*minibatch, :]
                cl_minibatch = class_labels[i*minibatch:(i+1)*minibatch, :]
                gen_images.append(self.G(z_minibatch, cl_minibatch, self.config["truncation"]))
            gen_images = torch.cat(gen_images)
            return gen_images

In [5]:
class GenerationProblem(Problem):
    def __init__(self, config):
        self.generator = Generator(config)
        self.config = config

        super().__init__(**self.config["problem_args"])

    def _evaluate(self, x, out, *args, **kwargs):
        ls = self.config["latent"](self.config)
        ls.set_from_population(x)

        with torch.no_grad():
            generated = self.generator.generate(ls, minibatch=self.config["batch_size"])
            sim = self.generator.clip_similarity(generated).cpu().numpy()
            if self.config["problem_args"]["n_obj"] == 2 and self.config["use_discriminator"]:
                dis = self.generator.discriminate(generated, minibatch=self.config["batch_size"])
                hinge = torch.relu(1 - dis)
                hinge = hinge.squeeze(1).cpu().numpy()
                out["F"] = np.column_stack((-sim, hinge))
            else:
                out["F"] = -sim

            out["G"] = np.zeros((x.shape[0]))

In [6]:
class DeepMindBigGANLatentSpace(torch.nn.Module):
    def __init__(self, config):
        super(DeepMindBigGANLatentSpace, self).__init__()
        self.config = config

        self.z = torch.nn.Parameter(torch.tensor(truncated_noise_sample(self.config["batch_size"])).to(self.config["device"]))
        self.class_labels = torch.nn.Parameter(torch.rand(self.config["batch_size"], self.config["num_classes"]).to(self.config["device"]))
    
    def set_values(self, z, class_labels):
        self.z.data = z
        self.class_labels.data = class_labels

    def set_from_population(self, x):
        self.z.data = torch.tensor(x[:,:self.config["dim_z"]].astype(float)).float().to(self.config["device"])
        self.class_labels.data = torch.tensor(x[:,self.config["dim_z"]:].astype(float)).float().to(self.config["device"])

    def forward(self):
        z = torch.clip(self.z, -2, 2)
        class_labels = torch.softmax(self.class_labels, dim=1)

        return z, class_labels

In [7]:
configs = dict(
    DeepMindBigGAN256 = dict(
        task = "txt2img",
        dim_z = 128,
        num_classes = 1000,
        latent = DeepMindBigGANLatentSpace,
        model = DeepMindBigGAN,
        weights = "biggan-deep-256",
        use_discriminator = False,
        algorithm = "ga",
        norm = biggan_norm,
        denorm = biggan_denorm,
        truncation = 1.0,
        pop_size = 64,
        batch_size = 8,
        problem_args = dict(
            n_var = 128 + 1000,
            n_obj = 1,
            n_constr = 128,
            xl = -2,
            xu = 2
        )
    ),
    DeepMindBigGAN512 = dict(
        task = "txt2img",
        dim_z = 128,
        num_classes = 1000,
        latent = DeepMindBigGANLatentSpace,
        model = DeepMindBigGAN,
        weights = "biggan-deep-512",
        use_discriminator = False,
        algorithm = "ga",
        norm = biggan_norm,
        denorm = biggan_denorm,
        truncation = 1.0, #1.0
        pop_size = 32, #32
        batch_size = 1, #8
        problem_args = dict(
            n_var = 128 + 1000,
            n_obj = 1,
            n_constr = 128,
            xl = -2,
            xu = 2
        )
    )
)



def get_config(name):
    return configs[name]

In [8]:
class Generator:
    def __init__(self, config):
        self.config = config
        self.augmentation = None

        self.CLIP, clip_preprocess = clip.load("ViT-B/32", device=self.config["device"], jit=False)
        
        #Load the fine tuned clip model
        checkpoint = torch.load('./models/model_fastai.pth')
        self.CLIP.load_state_dict(checkpoint['model'])

        self.CLIP = self.CLIP.eval()
        freeze_model(self.CLIP)
        self.model = self.config["model"](config).to(self.config["device"]).eval()
        freeze_model(self.model)
        
        if config["task"] == "txt2img":
            self.tokens = clip.tokenize([self.config["target"]]).to(self.config["device"])
            self.text_features = self.CLIP.encode_text(self.tokens).detach()
        if config["task"] == "img2txt":
            image = clip_preprocess(Image.open(self.config["target"])).unsqueeze(0).to(self.config["device"])
            self.image_features = self.CLIP.encode_image(image)

    def generate(self, ls, minibatch=None):
        z = ls()
        result = self.model.generate(*z, minibatch=minibatch)
        if hasattr(self.config, "norm"):
            result = self.config["norm"](result)
        return result
    
    def discriminate(self, images, minibatch=None):
        images = self.config["denorm"](images)
        return self.model.discriminate(images, minibatch)
    
    def has_discriminator(self):
        return self.model.has_discriminator()

    def clip_similarity(self, input):
        if self.config["task"] == "txt2img":
            image = kornia.resize(input, (224, 224))
            if self.augmentation is not None:
                image = self.augmentation(image)

            image_features = self.CLIP.encode_image(image)
            
            sim = torch.cosine_similarity(image_features, self.text_features)
        elif self.config["task"] == "img2txt":
            try:
                text_tokens = clip.tokenize(input).to(self.config["device"])
            except:
                return torch.zeros(len(input))
            text_features = self.CLIP.encode_text(text_tokens)

            sim = torch.cosine_similarity(text_features, self.image_features)
        return sim


    def save(self, input, path):
        if self.config["task"] == "txt2img":
            if input.shape[0] > 1:
                save_grid(input.detach().cpu(), path)
            else:
                save_image(input[0], path)
        elif self.config["task"] == "img2txt":
            f = open(path, "w")
            f.write("\n".join(input))
            f.close()

In [35]:
config = dict(
        device = "cuda",
        config = "DeepMindBigGAN256",
        generations = 550,
        save_each = 5,
        tmp_folder = "./tmp",
        target = "a woman man and a dog standing in the snow",
)


In [36]:
config.update(get_config(config["config"]))

In [37]:
config

{'device': 'cuda',
 'config': 'DeepMindBigGAN256',
 'generations': 550,
 'save_each': 5,
 'tmp_folder': './tmp',
 'target': 'a woman man and a dog standing in the snow',
 'task': 'txt2img',
 'dim_z': 128,
 'num_classes': 1000,
 'latent': __main__.DeepMindBigGANLatentSpace,
 'model': __main__.DeepMindBigGAN,
 'weights': 'biggan-deep-256',
 'use_discriminator': False,
 'algorithm': 'ga',
 'norm': <function __main__.biggan_norm(images)>,
 'denorm': <function __main__.biggan_denorm(images)>,
 'truncation': 1.0,
 'pop_size': 64,
 'batch_size': 8,
 'problem_args': {'n_var': 1128,
  'n_obj': 1,
  'n_constr': 128,
  'xl': -2,
  'xu': 2}}

In [38]:
iteration = 0
def save_callback(algorithm):
    global iteration
    global config

    iteration += 1
    if iteration % config["save_each"] == 0 or iteration == config["generations"]:
        if config["problem_args"]["n_obj"] == 1:
            sortedpop = sorted(algorithm.pop, key=lambda p: p.F)
            X = np.stack([p.X for p in sortedpop])  
        else:
            X = algorithm.pop.get("X")
        
        ls = config["latent"](config)
        ls.set_from_population(X)

        with torch.no_grad():
            generated = algorithm.problem.generator.generate(ls, minibatch=config["batch_size"])
            if config["task"] == "txt2img":
                ext = "jpg"
            elif config["task"] == "img2txt":
                ext = "txt"
            name = "genetic-it-%d.%s" % (iteration, ext) if iteration < config["generations"] else "genetic-it-final.%s" % (ext, )
            algorithm.problem.generator.save(generated, os.path.join(config["tmp_folder"], name))
        

problem = GenerationProblem(config)
operators = get_operators(config)

if not os.path.exists(config["tmp_folder"]): os.mkdir(config["tmp_folder"])

algorithm = get_algorithm(
    config["algorithm"],
    pop_size=config["pop_size"],
    sampling=operators["sampling"],
    crossover=operators["crossover"],
    mutation=operators["mutation"],
    eliminate_duplicates=True,
    callback=save_callback,
    **(config["algorithm_args"][config["algorithm"]] if "algorithm_args" in config and config["algorithm"] in config["algorithm_args"] else dict())
)

res = minimize(
    problem,
    algorithm,
    ("n_gen", config["generations"]),
    save_history=False,
    verbose=True,
)


pickle.dump(dict(
    X = res.X,
    F = res.F,
    G = res.G,
    CV = res.CV,
), open(os.path.join(config["tmp_folder"], "genetic_result"), "wb"))

if config["problem_args"]["n_obj"] == 2:
    plot = Scatter(labels=["similarity", "discriminator",])
    plot.add(res.F, color="red")
    plot.save(os.path.join(config["tmp_folder"], "F.jpg"))


if config["problem_args"]["n_obj"] == 1:
    sortedpop = sorted(res.pop, key=lambda p: p.F)
    X = np.stack([p.X for p in sortedpop])
else:
    X = res.pop.get("X")

ls = config["latent"](config)
ls.set_from_population(X)

torch.save(ls.state_dict(), os.path.join(config["tmp_folder"], "ls_result"))

if config["problem_args"]["n_obj"] == 1:
    X = np.atleast_2d(res.X)
else:
    try:
        result = get_decision_making("pseudo-weights", [0, 1]).do(res.F)
    except:
        print("Warning: cant use pseudo-weights")
        result = get_decomposition("asf").do(res.F, [0, 1]).argmin()

    X = res.X[result]
    X = np.atleast_2d(X)

ls.set_from_population(X)

with torch.no_grad():
    generated = problem.generator.generate(ls)

if config["task"] == "txt2img":
    ext = "jpg"
elif config["task"] == "img2txt":
    ext = "txt"

problem.generator.save(generated, os.path.join(config["tmp_folder"], "output.%s" % (ext)))

n_gen |  n_eval |   cv (min)   |   cv (avg)   |     fopt     |     favg    
    1 |      64 |  0.00000E+00 |  0.00000E+00 |      -0.2156 |      -0.1327
    2 |     128 |  0.00000E+00 |  0.00000E+00 |      -0.2666 |      -0.1653
    3 |     192 |  0.00000E+00 |  0.00000E+00 |      -0.2666 |      -0.1844
    4 |     256 |  0.00000E+00 |  0.00000E+00 |      -0.2666 |      -0.1958
    5 |     320 |  0.00000E+00 |  0.00000E+00 |      -0.2754 |      -0.2079
    6 |     384 |  0.00000E+00 |  0.00000E+00 |      -0.2754 |      -0.2134
    7 |     448 |  0.00000E+00 |  0.00000E+00 |      -0.2754 |      -0.2201
    8 |     512 |  0.00000E+00 |  0.00000E+00 |      -0.2754 |      -0.2263
    9 |     576 |  0.00000E+00 |  0.00000E+00 |      -0.2754 |      -0.2301
   10 |     640 |  0.00000E+00 |  0.00000E+00 |      -0.2844 |      -0.2361
   11 |     704 |  0.00000E+00 |  0.00000E+00 |      -0.3213 |      -0.2434
   12 |     768 |  0.00000E+00 |  0.00000E+00 |      -0.3213 |      -0.2473
   13 |     

  106 |    6784 |  0.00000E+00 |  0.00000E+00 |       -0.375 |       -0.356
  107 |    6848 |  0.00000E+00 |  0.00000E+00 |       -0.375 |      -0.3564
  108 |    6912 |  0.00000E+00 |  0.00000E+00 |       -0.375 |      -0.3564
  109 |    6976 |  0.00000E+00 |  0.00000E+00 |      -0.3757 |       -0.357
  110 |    7040 |  0.00000E+00 |  0.00000E+00 |      -0.3757 |      -0.3572
  111 |    7104 |  0.00000E+00 |  0.00000E+00 |      -0.3757 |       -0.358
  112 |    7168 |  0.00000E+00 |  0.00000E+00 |      -0.3757 |       -0.358
  113 |    7232 |  0.00000E+00 |  0.00000E+00 |      -0.3757 |       -0.358
  114 |    7296 |  0.00000E+00 |  0.00000E+00 |      -0.3757 |       -0.358
  115 |    7360 |  0.00000E+00 |  0.00000E+00 |      -0.3757 |       -0.358
  116 |    7424 |  0.00000E+00 |  0.00000E+00 |      -0.3757 |       -0.358
  117 |    7488 |  0.00000E+00 |  0.00000E+00 |      -0.3757 |      -0.3582
  118 |    7552 |  0.00000E+00 |  0.00000E+00 |      -0.3757 |      -0.3582
  119 |    7

  214 |   13696 |  0.00000E+00 |  0.00000E+00 |       -0.382 |       -0.365
  215 |   13760 |  0.00000E+00 |  0.00000E+00 |       -0.382 |       -0.365
  216 |   13824 |  0.00000E+00 |  0.00000E+00 |       -0.382 |       -0.365
  217 |   13888 |  0.00000E+00 |  0.00000E+00 |       -0.382 |       -0.365
  218 |   13952 |  0.00000E+00 |  0.00000E+00 |       -0.382 |       -0.365
  219 |   14016 |  0.00000E+00 |  0.00000E+00 |       -0.382 |      -0.3652
  220 |   14080 |  0.00000E+00 |  0.00000E+00 |       -0.382 |      -0.3652
  221 |   14144 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3655
  222 |   14208 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3657
  223 |   14272 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3657
  224 |   14336 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3657
  225 |   14400 |  0.00000E+00 |  0.00000E+00 |       -0.385 |       -0.366
  226 |   14464 |  0.00000E+00 |  0.00000E+00 |       -0.385 |       -0.366
  227 |   14

  322 |   20608 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3699
  323 |   20672 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3699
  324 |   20736 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3699
  325 |   20800 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3699
  326 |   20864 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3699
  327 |   20928 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3699
  328 |   20992 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3699
  329 |   21056 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3699
  330 |   21120 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3699
  331 |   21184 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3699
  332 |   21248 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3699
  333 |   21312 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3699
  334 |   21376 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3704
  335 |   21

  430 |   27520 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3723
  431 |   27584 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3723
  432 |   27648 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3723
  433 |   27712 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3723
  434 |   27776 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3723
  435 |   27840 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3723
  436 |   27904 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3723
  437 |   27968 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3723
  438 |   28032 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3723
  439 |   28096 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3723
  440 |   28160 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3723
  441 |   28224 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3723
  442 |   28288 |  0.00000E+00 |  0.00000E+00 |       -0.385 |      -0.3723
  443 |   28

  538 |   34432 |  0.00000E+00 |  0.00000E+00 |      -0.4004 |      -0.3745
  539 |   34496 |  0.00000E+00 |  0.00000E+00 |      -0.4004 |      -0.3745
  540 |   34560 |  0.00000E+00 |  0.00000E+00 |      -0.4004 |      -0.3745
  541 |   34624 |  0.00000E+00 |  0.00000E+00 |      -0.4004 |      -0.3745
  542 |   34688 |  0.00000E+00 |  0.00000E+00 |      -0.4004 |      -0.3745
  543 |   34752 |  0.00000E+00 |  0.00000E+00 |      -0.4004 |      -0.3745
  544 |   34816 |  0.00000E+00 |  0.00000E+00 |      -0.4004 |      -0.3745
  545 |   34880 |  0.00000E+00 |  0.00000E+00 |      -0.4004 |      -0.3745
  546 |   34944 |  0.00000E+00 |  0.00000E+00 |      -0.4004 |      -0.3748
  547 |   35008 |  0.00000E+00 |  0.00000E+00 |      -0.4004 |      -0.3748
  548 |   35072 |  0.00000E+00 |  0.00000E+00 |      -0.4004 |      -0.3748
  549 |   35136 |  0.00000E+00 |  0.00000E+00 |      -0.4004 |      -0.3748
  550 |   35200 |  0.00000E+00 |  0.00000E+00 |      -0.4004 |      -0.3748


In [None]:
import gc

In [None]:
torch.cuda.empty_cache()
gc.collect()