
Commit

Add embedding.py
sugyan committed Feb 11, 2020
1 parent 94fae58 commit 16d86f0
Showing 2 changed files with 88 additions and 1 deletion.
86 changes: 86 additions & 0 deletions embedding.py
@@ -0,0 +1,86 @@
import argparse
import tensorflow as tf


class LatentSpace(tf.keras.layers.Layer):
    # Trainable dlatents (shape 1 x 14 x 512) that are optimized to reproduce the target image.
    def __init__(self):
        super().__init__()
        self._variables = self.add_weight(shape=(1, 14, 512), dtype=tf.float32)

    def call(self, inputs):
        # The inputs are ignored; multiplying by 1.0 returns the variable as a tensor.
        return 1.0 * self._variables


class Synthesis(tf.keras.layers.Layer):
    # Wraps the 'synthesis' signature of the SavedModel exported by pkl2savedmodel.py.
    def __init__(self, model_path):
        super().__init__()
        model = tf.saved_model.load(model_path)
        self.synthesis = model.signatures['synthesis']

    def call(self, inputs):
        # Maps dlatents to the generated image tensor (the 'outputs' key added in this commit).
        return self.synthesis(dlatents=inputs)['outputs']


class GenerateLoss(tf.keras.losses.Loss):
    # Pixel-wise MSE plus a perceptual loss on selected VGG16 feature maps.
    def __init__(self, image):
        super().__init__()
        self.vgg16 = tf.keras.applications.VGG16(include_top=False)
        # Precompute the target image's feature maps once.
        self.outputs = []
        out = image
        for layer in self.vgg16.layers:
            out = layer(out)
            if layer.name in {'block1_conv1', 'block1_conv2', 'block3_conv2', 'block4_conv2'}:
                self.outputs.append(out)

    def call(self, y_true, y_pred):
        # Feature maps of the generated image at the same VGG16 layers.
        outputs = []
        out = y_pred
        for layer in self.vgg16.layers:
            out = layer(out)
            if layer.name in {'block1_conv1', 'block1_conv2', 'block3_conv2', 'block4_conv2'}:
                outputs.append(out)
        # Pixel loss, normalized by the number of elements.
        n = tf.cast(tf.reduce_prod(y_pred.shape), tf.float32)
        losses = tf.math.reduce_sum(tf.keras.losses.MSE(y_true, y_pred)) / n
        # Add the normalized feature-map losses.
        for i, out in enumerate(outputs):
            n = tf.cast(tf.math.reduce_prod(out.shape), tf.float32)
            losses += tf.math.reduce_sum(tf.keras.losses.MSE(self.outputs[i], out)) / n
        return losses


class GenerateCallback(tf.keras.callbacks.Callback):
    # After each epoch, synthesize an image from the current dlatents and save it as a PNG.
    def on_epoch_end(self, epoch, logs):
        v = self.model.layers[0].variables[0].numpy()
        images = self.model.layers[1](v)
        # Convert from roughly [-1, 1] to uint8 [0, 255].
        images = tf.saturate_cast((images + 1.0) * 127.5, tf.uint8)
        with open(f'epoch{epoch:03d}.png', 'wb') as fp:
            data = tf.image.encode_png(tf.squeeze(images, axis=0)).numpy()
            fp.write(data)


def run(model_path, target_image):
    # Load the target image and normalize it to [-1, 1].
    with open(target_image, 'rb') as fp:
        y = tf.image.decode_jpeg(fp.read())
    y = tf.expand_dims(tf.cast(y, tf.float32) / 127.5 - 1.0, axis=0)

    # Only the LatentSpace weights are trainable; Synthesis wraps the frozen generator.
    model = tf.keras.Sequential([
        LatentSpace(),
        Synthesis(model_path),
    ])
    model(tf.zeros([]))  # build the model with a dummy input
    model.summary()
    model.compile(loss=GenerateLoss(y))
    # The dataset yields a dummy input and the target image; repeat it so fit() keeps optimizing.
    dataset = tf.data.Dataset.from_tensors(([], y))
    model.fit(
        dataset.repeat().batch(1),
        steps_per_epoch=10,
        epochs=20,
        callbacks=[GenerateCallback()])


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('model_path', type=str)
    parser.add_argument('target_image', type=str)
    args = parser.parse_args()

    run(args.model_path, args.target_image)
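
For reference, the script takes only the two positional arguments defined above, so a typical invocation would be (paths are illustrative):

python embedding.py ./saved_model ./target.jpg

Here model_path must point at a SavedModel exported by pkl2savedmodel.py, and each of the 20 epochs writes an epoch000.png, epoch001.png, ... snapshot of the current reconstruction via GenerateCallback.
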
3 changes: 2 additions & 1 deletion pkl2savedmodel.py
@@ -62,7 +62,8 @@ def convert(network_pkl, save_dir):
            {'dlatents': tf.saved_model.utils.build_tensor_info(outputs[1])}),
        'synthesis': tf.compat.v1.saved_model.build_signature_def(
            {'dlatents': tf.saved_model.utils.build_tensor_info(outputs[1])},
-           {'images': tf.saved_model.utils.build_tensor_info(images)})
+           {'images': tf.saved_model.utils.build_tensor_info(images),
+            'outputs': tf.saved_model.utils.build_tensor_info(tf.transpose(outputs[0], [0, 2, 3, 1]))})
    }
    builder.add_meta_graph_and_variables(
        sess,
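
As a quick illustration of the new signature output, a minimal sketch of loading the exported SavedModel and reading both keys (the './saved_model' path and the 1 x 14 x 512 dlatents shape are illustrative, taken from embedding.py):

import tensorflow as tf

# Load the SavedModel written by pkl2savedmodel.py (path is illustrative).
model = tf.saved_model.load('./saved_model')
synthesis = model.signatures['synthesis']

dlatents = tf.zeros([1, 14, 512], dtype=tf.float32)
result = synthesis(dlatents=dlatents)
images = result['images']    # previously the only output of the 'synthesis' signature
outputs = result['outputs']  # NHWC tensor added by this commit, consumed by embedding.py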
