In [1]:
import os

import numpy as np
from PIL import Image

import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S
from nnabla.ext_utils import get_extension_context
from nnabla.utils.image_utils import imsave

from gan.generate import synthesis
from gan.networks import mapping_network, conv_block
from gan.ops import upsample_2d, upsample_conv_2d, lerp, convert_images_to_uint8, weight_init_fn

import clip

2021-09-15 09:49:25,616 [nnabla][INFO]: Initializing CPU extension...


In [2]:
# params
# Seed for the NumPy RandomState that samples the initial noise z.
SEED = 66
# Number of latent vectors / images handled at once.
batch_size = 1

# Step size handed to the Adam solver as `alpha`.
LR = 1e-2
# NOTE(review): not referenced anywhere else in the visible notebook —
# confirm whether weight decay was meant to be applied to the solver.
WEIGHT_DECAY = 1e-5

# Number of optimization steps executed by the train loop.
EPOCHS = 200
# Step at which the offset (latent - latent_init) is captured into `diff`.
diff_epoch = 110

# Interpolation factor used in lerp(dlatent_avg, latent, truncation_psi).
truncation_psi = 0.5
# Passed to synthesis() as `resolution`.
resolution = 1024

# An intermediate image is written every `imsave_freq` steps.
imsave_freq = 5

# When True, a squared-L2 penalty on the latent offset, weighted by
# l2_lambda, is added to the CLIP loss.
use_l2 = False
l2_lambda = 0.008

# Parameter file loaded into the 'gan' scope (StyleGAN weights, per the
# comment in the loading cell).
gan_path = './gan/face.h5'

# nnabla extension used as the default context for every computation.
context = 'cudnn'
ctx = get_extension_context(context)
nn.set_default_context(ctx)

2021-09-15 09:49:28,440 [nnabla][INFO]: Initializing CUDA extension...
2021-09-15 09:49:28,455 [nnabla][INFO]: Initializing cuDNN extension...


In [3]:
# Enable nnabla's auto-forward mode for the whole notebook; later cells read
# `.d` from results without calling forward() explicitly.
nn.set_auto_forward(True)

### input text

In [4]:
# Prompt that the generated image is optimized towards (passed to clip_loss).
text = "a man with blonde hair"

### CLIP loss function

In [5]:
def _normalize(img, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711), max_pixel_value=1.0):
    """Standardize an image per channel: ``(img - mean * max) / (std * max)``.

    Args:
        img: nnabla variable holding the image. The per-channel constants are
            shaped (3, 1, 1), so the channel axis (size 3) must be the
            third-from-last axis, e.g. (3, H, W).
        mean: Per-channel mean, three values.
        std: Per-channel standard deviation, three values.
        max_pixel_value: Scale applied to both ``mean`` and ``std``.

    Returns:
        A new variable with the normalized image. Only out-of-place
        arithmetic is used, so the caller's ``img`` is left untouched.
    """
    # The constants do not depend on the image, so compute them in NumPy and
    # wrap them once. The final astype keeps them float32 even if
    # max_pixel_value is a higher-precision NumPy scalar.
    offset = (np.asarray(mean, dtype=np.float32) * max_pixel_value).astype(np.float32)
    scale = (1.0 / (np.asarray(std, dtype=np.float32) * max_pixel_value)).astype(np.float32)

    offset = nn.Variable.from_numpy_array(offset.reshape(3, 1, 1))
    scale = nn.Variable.from_numpy_array(scale.reshape(3, 1, 1))

    # Multiplying by the reciprocal of std is the same as dividing by std.
    return (img - offset) * scale

def clip_loss(image, text):
    """Return the CLIP-based loss ``1 - logits / 100`` for an image and a prompt.

    Args:
        image: Image variable; it is resized to 224x224 and normalized with
            ``_normalize`` before being passed to CLIP.
        text: Prompt string, tokenized with ``clip.tokenize``.

    Returns:
        ``1 - img_logits / 100`` where ``img_logits`` is the first value
        returned by ``clip.logits``. Lower values mean a higher image logit.
        (The division by 100 presumably undoes CLIP's logit scale — confirm.)
    """
    with nn.parameter_scope('clip'):
        # Bring the image to the 224x224 input size before normalization.
        resized = F.interpolate(image, output_size=(224, 224))
        clip_input = _normalize(resized)

        tokens = clip.tokenize(text)

        logits_per_image, _ = clip.logits(clip_input, tokens)
        loss = 1 - logits_per_image / 100

    return loss

In [6]:
# Load the CLIP weights inside the 'clip' parameter scope — the same scope
# clip_loss enters — so they stay separate from the 'gan' parameters.
with nn.parameter_scope('clip'):
    clip.load('data/ViT-B-32.h5')

In [7]:
# NOTE(review): this cell contained only a bare `!` (an empty shell escape)
# and executed nothing; the cell can be deleted.

### Latent parameters to learn

In [8]:
# Sample one noise vector z of width 512 per batch element, reproducibly via SEED.
rnd = np.random.RandomState(SEED)
z = rnd.randn(batch_size, 512)

# The same z is wrapped twice, so both style inputs start from identical noise.
style_noises = [nn.NdArray.from_numpy_array(z) for _ in range(2)]

In [9]:
# Load the StyleGAN parameters and map the input noise to latent codes `w`.
with nn.parameter_scope('gan'):
    nn.load_parameters(gan_path)

    # Constant 4x4 tensor fetched from the loaded parameters, repeated along
    # the batch axis; it is later handed to synthesis().
    constant = nn.parameter.get_parameter_or_create(
        name="G_synthesis/4x4/Const/const", shape=(1, 512, 4, 4))
    constant_bc = F.broadcast(constant, (batch_size,) + constant.shape[1:])

    # Reference latent used by lerp() together with truncation_psi.
    dlatent_avg = nn.parameter.get_parameter_or_create(
        name="dlatent_avg", shape=(1, 512))

    # Scale every noise vector by its root-mean-square over axis 1
    # (the epsilon guards against division by zero).
    style_noises_normalized = []
    for noise in style_noises:
        mean_square = F.mean(noise ** 2., axis=1, keepdims=True)
        rms = (mean_square + 1e-8) ** 0.5
        style_noises_normalized.append(F.div2(noise, rms))

    w = [mapping_network(noise, outmaps=512)
         for noise in style_noises_normalized]

In [10]:
# Learnable copies of the two mapped latents, registered as "latent/0" and
# "latent/1" and initialized from the values in `w`.
latent = [
    nn.parameter.get_parameter_or_create(
        name=f"latent/{idx}", shape=w[idx].shape,
        initializer=w[idx].data, need_grad=True)
    for idx in range(2)
]

# Snapshot of the starting values; used for the optional L2 penalty and for
# the offset stored in `diff`.
latent_init = [nn.Variable.from_numpy_array(param.d.copy()) for param in latent]

# Filled by the train loop with (latent - latent_init) at `diff_epoch`.
diff = []

In [11]:
# Only parameters registered under the 'latent' scope are optimized; the
# 'gan' and 'clip' parameters are never handed to the solver.
with nn.parameter_scope('latent'):
    latent_params = nn.get_parameters()

solver = S.Adam(alpha=LR)
solver.set_parameters(latent_params)

### Train loop

In [12]:
# Optimize the two latent codes so that the synthesized image lowers the
# CLIP loss for `text`.

# Make sure the output directory exists before the first imsave() call.
os.makedirs('results', exist_ok=True)

for epoch in range(EPOCHS):
    with nn.parameter_scope('gan'):
        # Interpolate between dlatent_avg and each learnable latent with
        # factor truncation_psi, then synthesize the image.
        w = [lerp(dlatent_avg, _, truncation_psi) for _ in latent]
        rgb_output = synthesis(w, constant_bc, 1, 7, resolution=resolution)

    # Periodically dump the current image; these frames are read back by the
    # video-export cell.
    if epoch % imsave_freq == 0:
        img = convert_images_to_uint8(rgb_output, drange=[-1, 1])
        imsave(f'results/{epoch}.png', img[0], channel_first=True)

    # Only the first sample of the batch contributes to the loss.
    loss = clip_loss(rgb_output[0], text)
    if use_l2:
        # Penalize drifting away from the initial latents. The reshape to
        # (1, 1) is presumably to match the shape of the CLIP loss — confirm.
        l2_loss = F.sum((latent_init[0] - latent[0]) ** 2) + F.sum((latent_init[1] - latent[1]) ** 2)
        l2_loss = l2_loss.reshape((1, 1))
        loss = loss + l2_lambda * l2_loss

    if epoch == diff_epoch:
        # Replace (rather than append to) the stored offsets so that
        # re-running this cell never leaves more than one pair in `diff`.
        diff[:] = [(cur - init).d.copy() for cur, init in zip(latent, latent_init)]

    solver.zero_grad()
    loss.backward(clear_buffer=True)
    solver.update()

    if epoch % 10 == 0:
        # .item() extracts the single loss value; float() on an array with
        # ndim > 0 is deprecated in recent NumPy releases.
        print('epoch: {}/{} - loss: {:.5f}'.format(epoch, EPOCHS, loss.d.item()))

epoch: 0/200 - loss: 0.75592
epoch: 10/200 - loss: 0.71544
epoch: 20/200 - loss: 0.69196
epoch: 30/200 - loss: 0.67529
epoch: 40/200 - loss: 0.66550
epoch: 50/200 - loss: 0.65958
epoch: 60/200 - loss: 0.65478
epoch: 70/200 - loss: 0.65097
epoch: 80/200 - loss: 0.64785
epoch: 90/200 - loss: 0.64506
epoch: 100/200 - loss: 0.64204
epoch: 110/200 - loss: 0.63864
epoch: 120/200 - loss: 0.63538
epoch: 130/200 - loss: 0.63249
epoch: 140/200 - loss: 0.62984
epoch: 150/200 - loss: 0.62685
epoch: 160/200 - loss: 0.62370
epoch: 170/200 - loss: 0.62071
epoch: 180/200 - loss: 0.61771
epoch: 190/200 - loss: 0.61514


### Result check: video of the saved frames

In [13]:
import cv2

# Assemble the frames saved by the train loop into video.mp4 at 1 frame/s.
# encoder (for mp4)
fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
# output file name, encoder, fps, frame size. The size follows `resolution`
# instead of a hard-coded 1024 so it stays in sync with the config cell
# (assumes the saved frames are resolution x resolution — confirm).
video = cv2.VideoWriter('video.mp4', fourcc, 1, (resolution, resolution))

try:
    for i in range(0, EPOCHS, imsave_freq):
        # results/0.png, results/5.png, ... as written by the train loop
        img = cv2.imread(f'results/{i}.png')

        # cv2.imread returns None when the file cannot be read; stop there
        if img is None:
            print("can't read")
            break

        # append the frame
        video.write(img)
finally:
    # Release the writer even if reading or encoding raises.
    video.release()
print('written')

written


In [14]:
# Wrap the latent offsets captured by the train loop as nnabla Variables.
# NOTE(review): `diff` is only filled when epoch == diff_epoch, so it (and q)
# stays empty if diff_epoch >= EPOCHS; the later q[0] / q[1] would then fail.
q = [nn.Variable.from_numpy_array(n) for n in diff]

In [15]:
# Repeat the noise -> latent mapping for a different seed (67) to obtain a
# second starting point, `new_w`.
rnd = np.random.RandomState(67)
z = rnd.randn(batch_size, 512)

style_noises = [nn.NdArray.from_numpy_array(z) for _ in range(2)]

# Load the StyleGAN parameters
with nn.parameter_scope('gan'):
    nn.load_parameters(gan_path)

    # Constant 4x4 tensor repeated along the batch axis; handed to synthesis().
    constant = nn.parameter.get_parameter_or_create(
        name="G_synthesis/4x4/Const/const", shape=(1, 512, 4, 4))
    constant_bc = F.broadcast(constant, (batch_size,) + constant.shape[1:])

    # Reference latent used by lerp() together with truncation_psi.
    dlatent_avg = nn.parameter.get_parameter_or_create(
        name="dlatent_avg", shape=(1, 512))

    # Scale every noise vector by its root-mean-square over axis 1
    # (the epsilon guards against division by zero).
    style_noises_normalized = []
    for noise in style_noises:
        mean_square = F.mean(noise ** 2., axis=1, keepdims=True)
        rms = (mean_square + 1e-8) ** 0.5
        style_noises_normalized.append(F.div2(noise, rms))

    new_w = [mapping_network(noise, outmaps=512)
             for noise in style_noises_normalized]

In [16]:
# Display the mapped latents for the new seed.
new_w

[<NdArray((1, 512)) at 0x2b931fbb96f0>, <NdArray((1, 512)) at 0x2b931fbb9690>]

In [17]:
# Shift the new latents by the offsets captured during training.
# NOTE(review): `diff` was stored as (latent - latent_init), i.e. the direction
# the optimization moved in. Subtracting it moves the new latents the opposite
# way — confirm whether `+` was intended here.
new_latent = []
new_latent.append(new_w[0] - q[0])
new_latent.append(new_w[1] - q[1])

In [18]:
# Display the shifted latents.
new_latent

[<NdArray((1, 512)) at 0x2b931fbb97b0>, <NdArray((1, 512)) at 0x2b931fbb98a0>]

In [19]:
# Synthesize an image from the shifted latents, using the same interpolation
# with dlatent_avg and the same synthesis arguments as the train loop.
with nn.parameter_scope('gan'):
    w = [lerp(dlatent_avg, shifted, truncation_psi) for shifted in new_latent]
    new_rgb_output = synthesis(w, constant_bc, 1, 7, resolution=resolution)

In [20]:
# Display the synthesized output for the shifted latents.
new_rgb_output

<NdArray((1, 3, 1024, 1024)) at 0x2b931fbb9d80>

In [21]:
# Convert the output (declared range [-1, 1]) to uint8 and save the first
# sample of the batch in channel-first layout.
img = convert_images_to_uint8(new_rgb_output, drange=[-1, 1])
imsave('res_6.png', img[0], channel_first=True)